diff --git a/.github/workflows/docker_security.yml b/.github/workflows/docker_security.yml index 124b1f78a..43ec8c6f9 100644 --- a/.github/workflows/docker_security.yml +++ b/.github/workflows/docker_security.yml @@ -4,11 +4,14 @@ permissions: contents: read security-events: write pull-requests: write + actions: read on: push: branches: - master + pull_request: + branches: [master] schedule: - cron: "0 0 * * *" # Runs daily at midnight UTC diff --git a/.github/workflows/pip_audit.yml b/.github/workflows/pip_audit.yml index 4076efbcd..963e63ba1 100644 --- a/.github/workflows/pip_audit.yml +++ b/.github/workflows/pip_audit.yml @@ -9,15 +9,13 @@ on: - cron: '00 00 * * *' push: branches: [master] - pull_request: - branches: [master] concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} cancel-in-progress: true jobs: - build: + audit: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 @@ -28,317 +26,34 @@ jobs: enable-cache: true version: "0.7.12" - - name: Configure Git - run: | - git config --global user.name 'github-actions[bot]' - git config --global user.email 'github-actions[bot]@users.noreply.github.com' - - name: Audit dependencies and identify vulnerabilities - id: audit run: | # Export requirements for pip-audit to analyze uv export --all-extras --format requirements-txt --no-emit-project > requirements.txt # Run pip-audit but don't fail if vulnerabilities are found - uvx pip-audit -r requirements.txt --disable-pip -v > pip_audit_results.txt || true - - # Check if vulnerabilities were found - if [ ! -s pip_audit_results.txt ]; then - echo "has_vulnerabilities=false" >> $GITHUB_OUTPUT - else - echo "has_vulnerabilities=true" >> $GITHUB_OUTPUT - - # Create a detailed mapping of all vulnerabilities for later use - { - # Add a header row for the CSV format - echo "pkg_name,current_ver,vuln_id,fixed_ver" - - # Extract all vulnerabilities with their details - grep -v "^Name\|^------" pip_audit_results.txt | while read -r line; do - if [[ -n "$line" ]]; then - # Extract fields: package name, current version, vulnerability ID, fixed version - pkg_name=$(echo "$line" | awk '{print $1}') - current_ver=$(echo "$line" | awk '{print $2}') - vuln_id=$(echo "$line" | awk '{print $3}') - fixed_ver=$(echo "$line" | awk '{print $NF}') - - # Output as CSV - echo "$pkg_name,$current_ver,$vuln_id,$fixed_ver" - fi - done - } > all_vulnerabilities.csv - - # Store all_vulnerabilities.csv as an artifact - echo "all_vulns_data<> $GITHUB_OUTPUT - cat all_vulnerabilities.csv >> $GITHUB_OUTPUT - echo "EOF" >> $GITHUB_OUTPUT - - # Get unique packages with their highest fixed version - { - echo "Processing unique packages with highest fixed versions:" - - # Use awk to process the CSV and find highest versions - awk -F, 'BEGIN {OFS=","} - # Custom function for semantic version comparison - function version_gt(v1, v2) { - n1 = split(v1, a, "[.-]") - n2 = split(v2, b, "[.-]") - - # Compare each version component - for (i = 1; i <= n1 && i <= n2; i++) { - if (a[i] == b[i]) continue - return (a[i]+0) > (b[i]+0) - } - return n1 > n2 - } - NR == 1 {next} # Skip header - { - pkg = $1 - curr_ver = $2 - vuln = $3 - fix_ver = $4 - - print "Found=" pkg, "current=" curr_ver, "vuln=" vuln, "fix=" fix_ver - - # Check if we have seen this package before - if (!(pkg in highest_ver) || version_gt(fix_ver, highest_ver[pkg])) { - highest_ver[pkg] = fix_ver - print " Updated highest version for", pkg, "to", fix_ver - } - } - END { - # Output unique packages with highest versions - for (pkg in highest_ver) { - print pkg "==" highest_ver[pkg] - } - }' all_vulnerabilities.csv - } > unique_packages.txt - - # Store the consolidated package list - consolidated_packages=$(cat unique_packages.txt | grep -v "^Processing\|^Found\|^ Updated" | sort) - echo "vulnerable_packages<> $GITHUB_OUTPUT - echo "$consolidated_packages" >> $GITHUB_OUTPUT - echo "EOF" >> $GITHUB_OUTPUT - fi - - - name: Process vulnerable packages individually - if: steps.audit.outputs.has_vulnerabilities == 'true' - id: process_packages - run: | - # Build the JSON array in a variable first - json_data="[" - first_item=true - - # Store all vulnerability data for reference - all_vulns="${{ steps.audit.outputs.all_vulns_data }}" - - while IFS= read -r line; do - if [[ -n "$line" && $line =~ ([^=]+)==(.+) ]]; then - pkg_name="${BASH_REMATCH[1]}" - pkg_version="${BASH_REMATCH[2]}" - - echo "Processing package: $pkg_name -> $pkg_version" - - # Get current version from the first vulnerability entry - current_ver=$(echo "$all_vulns" | grep -m 1 "^$pkg_name," | cut -d',' -f2) - - # Get all vulnerability IDs for this package - vuln_ids=$(echo "$all_vulns" | grep "^$pkg_name," | cut -d',' -f3 | sort -u | paste -sd "," -) - - # Create signature specific to this package - pkg_signature=$(echo "$pkg_name-$pkg_version" | md5sum | cut -d ' ' -f1) - - echo " Current version: $current_ver" - echo " Vulnerabilities: $vuln_ids" - echo " Signature: $pkg_signature" - - # Add to JSON (with comma if not first) - if [ "$first_item" = "true" ]; then - first_item=false - else - json_data+="," - fi - - # Escape any special characters in the values - pkg_name_esc=$(echo "$pkg_name" | jq -R .) - pkg_version_esc=$(echo "$pkg_version" | jq -R .) - current_ver_esc=$(echo "$current_ver" | jq -R .) - vuln_ids_esc=$(echo "$vuln_ids" | jq -R .) - - # Build the JSON object with proper escaping - json_data+="{\"name\":${pkg_name_esc},\"version\":${pkg_version_esc},\"current_version\":${current_ver_esc},\"vuln_id\":${vuln_ids_esc},\"signature\":\"$pkg_signature\"}" - fi - done <<< "${{ steps.audit.outputs.vulnerable_packages }}" - - # Close the JSON array - json_data+="]" - - # Use the multiline delimiter syntax for GitHub Actions outputs - echo "package_data<> $GITHUB_OUTPUT - echo "$json_data" >> $GITHUB_OUTPUT - echo "EOF" >> $GITHUB_OUTPUT - - outputs: - has_vulnerabilities: ${{ steps.audit.outputs.has_vulnerabilities }} - package_data: ${{ steps.process_packages.outputs.package_data }} - all_vulns_data: ${{ steps.audit.outputs.all_vulns_data }} - - update_packages: - needs: build - if: needs.build.outputs.has_vulnerabilities == 'true' - runs-on: ubuntu-latest - strategy: - matrix: - package: ${{ fromJSON(needs.build.outputs.package_data) }} - # Allow other package updates to continue if one fails - fail-fast: false - # Limit concurrent jobs to avoid API rate limits - max-parallel: 5 - - steps: - - uses: actions/checkout@v4 - - - name: Set up uv - uses: astral-sh/setup-uv@v6 - with: - enable-cache: true - version: "0.7.12" - - - name: Check for existing PRs - id: check_prs - run: | - # Check for existing PRs with this package name - pkg_name="${{ matrix.package.name }}" - existing_pr=$(gh pr list --json number,title,body --search "in:title security update for $pkg_name" --jq '.[0]') - - if [[ -n "$existing_pr" ]]; then - pr_number=$(echo "$existing_pr" | jq -r '.number') - echo "Found existing PR #$pr_number for $pkg_name" - - # Check if PR contains an older version of the same package - pr_body=$(echo "$existing_pr" | jq -r '.body') - if echo "$pr_body" | grep -q "Package signature: ${{ matrix.package.signature }}"; then - echo "Found PR with identical package version - skipping" - echo "skip_pr_creation=true" >> $GITHUB_OUTPUT - exit 0 - fi - - # PR exists but for a different version - we'll close it and create new one - echo "PR exists for different version - will close and create new PR" - gh pr close $pr_number --comment "Closing in favor of PR with newer version ${pkg_name}==${matrix.package.version}" - fi - - echo "Will create new PR for ${pkg_name}==${{ matrix.package.version }}" - echo "skip_pr_creation=false" >> $GITHUB_OUTPUT - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - - - name: Configure Git - run: | - git config --global user.name 'github-actions[bot]' - git config --global user.email 'github-actions[bot]@users.noreply.github.com' - - - name: Update package - if: steps.check_prs.outputs.skip_pr_creation == 'false' - id: update - continue-on-error: true # Continue to cleanup step even if this fails - run: | - # Create a unique branch name for this package - branch_name="security-update-${{ matrix.package.name }}-${{ github.run_id }}" - echo "branch_name=$branch_name" >> $GITHUB_OUTPUT - - # Ensure we're on master and it's up-to-date - git fetch origin master - git checkout master - git pull origin master - - # Create new branch for this package only - git checkout -b $branch_name - - echo "Setting up uv environment..." - uv sync --frozen --all-extras - - # Update only this specific package in the lock file - echo "Updating ${{ matrix.package.name }} to ${{ matrix.package.version }}" - uv lock --upgrade-package "${{ matrix.package.name }}==${{ matrix.package.version }}" - - # Verify changes were made - if git diff --quiet uv.lock; then - echo "No changes detected in uv.lock file. This might indicate an issue with the update process." - exit 1 - fi - - # Commit changes - git add uv.lock - git commit -m "fix(security): update ${{ matrix.package.name }} to ${{ matrix.package.version }}" - - # Push to the remote branch - git push origin $branch_name + uvx pip-audit -r requirements.txt --disable-pip --desc off --format json > pip_audit_results.txt || true - - name: Create package-specific PR report with all vulnerabilities - if: steps.check_prs.outputs.skip_pr_creation == 'false' && steps.update.outcome == 'success' - id: create_report + - name: Process audit information run: | - # Get all vulnerability details for this package from the CSV - all_vulns="${{ needs.build.outputs.all_vulns_data }}" - - # Create PR description with comprehensive vulnerability information - { - echo "# Security Update: ${{ matrix.package.name }}" - echo "" - echo "This PR updates **${{ matrix.package.name }}** from version ${{ matrix.package.current_version }} to **${{ matrix.package.version }}** to fix the following security vulnerabilities:" - echo "" - - # List all vulnerabilities for this package - echo "## Vulnerability Details" - echo "" - echo "| Vulnerability ID | Affected Version | Fixed Version |" - echo "| --------------- | --------------- | ------------ |" - - # Parse the CSV data to extract vulnerabilities for this package - echo "$all_vulns" | grep -v "^pkg_name" | grep "^${{ matrix.package.name }}," | while IFS=, read -r pkg curr_ver vuln_id fixed_ver; do - # If the vulnerability is fixed by the version we're updating to, include it - echo "| $vuln_id | $curr_ver | $fixed_ver |" - done + # Avoid downloading and installing entire project and all dependencies + uv run --no-sync --isolated --with packaging runscripts/debug/process_vulnerabilities.py pip_audit_results.txt - echo "" - echo "Close and reopen this PR to trigger the CI/CD pipelines before merging." - echo "" + - name: Apply package updates + run: | + ./apply_security_upgrades.sh - echo "" - echo "" - } > pr_description.md - - cat pr_description.md - - name: Create Pull Request - if: steps.check_prs.outputs.skip_pr_creation == 'false' && steps.update.outcome == 'success' - id: create_pr - continue-on-error: true - run: | - gh pr create \ - --title "Security update for ${{ matrix.package.name }} to ${{ matrix.package.version }}" \ - --body-file pr_description.md \ - --base master \ - --head ${{ steps.update.outputs.branch_name }} - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - - - name: Cleanup on failure - if: | - steps.check_prs.outputs.skip_pr_creation == 'false' && - (steps.update.outcome == 'failure' || steps.create_pr.outcome == 'failure') && - steps.update.outputs.branch_name != '' - run: | - echo "Cleaning up branch due to workflow failure..." - branch_name="${{ steps.update.outputs.branch_name }}" - - # Check if branch exists before attempting to delete - if git ls-remote --heads origin $branch_name | grep -q $branch_name; then - echo "Deleting branch: $branch_name" - git push origin --delete $branch_name - else - echo "Branch $branch_name does not exist or was not created" - fi + uses: peter-evans/create-pull-request@v7 + with: + commit-message: | + fix(security): update package versions + sign-commits: true + title: | + Security updates + body-path: vulnerability_report.md + delete-branch: true + branch: security-updates + add-paths: uv.lock env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/README.md b/README.md index a000b1b55..bd79d7535 100644 --- a/README.md +++ b/README.md @@ -40,7 +40,7 @@ If your system has git, curl (or wget), and a C compiler On Ubuntu/Debian, apt can be used to install all three prerequisites: - sudo -s eval 'apt update && apt install git curl clang' + sudo -s eval 'apt update && apt install -y git curl clang' On MacOS, curl is preinstalled and git and clang come with the Xcode Command Line Tools: @@ -68,7 +68,7 @@ git clone https://github.com/CovertLab/vEcoli.git > a new directory called `vEcoli` in your current directory. To speed up > the clone and save disk space, add `--filter=blob:none` to the command. -2. [Follow these instructions](https://docs.astral.sh/uv/getting-started/installation/) +2. [Follow these "Standalone installer" instructions](https://docs.astral.sh/uv/getting-started/installation/) to install `uv`, our Python package and project manager of choice. 3. Close and reopen your terminal. @@ -84,6 +84,10 @@ uv sync --frozen --extra dev uv run pre-commit install ``` +> **Tip:** If uv is not connecting to the venv correctly, or you are running into an error with the +> `uv run pre-commit install` step, try running `rm -rf .venv` to remove the venv, then run +> `uv sync --frozen --extra dev` followed by `uv run pre-commit install` to reinstall the venv. + 5. Install `nextflow` [following these instructions](https://www.nextflow.io/docs/latest/install.html). If your system has `wget` but not `curl`, replace `curl` in the commands with `wget -qO-`. If you choose to install Java with SDKMAN!, after @@ -155,4 +159,4 @@ and contains more information about the model architecture, output, and workflow configuration. If you encounter an issue not addressed by the docs, feel free to create a GitHub issue, and we will -get back to you as soon as we can. +get back to you as soon as we can. \ No newline at end of file diff --git a/doc/experiments.rst b/doc/experiments.rst index cdc1c7efb..60846e59a 100644 --- a/doc/experiments.rst +++ b/doc/experiments.rst @@ -134,6 +134,7 @@ documented in :ref:`/workflows.rst`. # simulations run using ecoli/experiments/ecoli_master_sim.py. Workflows # run with runscripts/workflow.py generate initial seeds using the value # of a different configuration option named "lineage_seed". + # Both seed and lineage_seed are supposed to be integers. "seed": 0, # Special flags to enable mechanisms related to antibiotic resistance. # See API documentation for ecoli.library.sim_data.LoadSimData for more diff --git a/doc/gcloud.rst b/doc/gcloud.rst index 6080f579f..9f7553264 100644 --- a/doc/gcloud.rst +++ b/doc/gcloud.rst @@ -123,21 +123,28 @@ right service account and project. Next, install Git and clone the vEcoli reposi .. code-block:: bash - sudo apt update && sudo apt install git - git clone https://github.com/CovertLab/vEcoli.git + # zip and unzip necessary to install SDKMAN to get Java for nextflow + sudo apt update && sudo apt install -y git zip unzip + git clone https://github.com/CovertLab/vEcoli.git --filter=blob:none + cd vEcoli -Now follow the installation instructions from the README starting with -installing ``uv`` and finishing with installing Nextflow. +`Install uv `_, then +create a new virtual environment and install GCSFS: -.. note:: - Technically, the only requirements to run :mod:`runscripts.workflow` on Google Cloud - are Nextflow, Python 3.9+, and `GCSFS `_. - The workflow steps will be run inside Docker containers (see - :ref:`docker-images`). The other Python requirements can be - omitted for a more minimal installation. You will need to use - :ref:`interactive containers ` to run the model using - any interface other than :mod:`runscripts.workflow`, but this may be a good - thing for maximum reproducibility. +.. code-block:: bash + + source ~/.bashrc + uv venv + uv pip install gcsfs + +Run the following to automatically activate the virtual environment: + +.. code-block:: bash + + echo "source ~/vEcoli/.venv/bin/activate" >> ~/.bashrc + source ~/.bashrc + +Finally, `install Nextflow `_. ------------------ Create Your Bucket diff --git a/doc/hpc.rst b/doc/hpc.rst index e4cc1dcf1..85346b483 100644 --- a/doc/hpc.rst +++ b/doc/hpc.rst @@ -32,6 +32,72 @@ Setup .. note:: The following setup applies to members of the Covert Lab only. +Request a Sherlock Account +-------------------------- + +If you've never had a Sherlock account: Go to https://www.sherlock.stanford.edu/ and click on ``Request an Account`` + +.. note:: + Markus will have to approve this. + +If you've had a Sherlock account for a previous group: Email srcc-support@stanford.edu and ask them to move your account to mcovert, and CC Markus on the email and in the email body ask for Markus to give approval + +Additional Resources: Sherlock Documentation from Stanford +---------------------------------------------------------- + +* https://srcc.stanford.edu/workshops/sherlock-boarding-session +* https://www.sherlock.stanford.edu/docs/ + +Login to Sherlock +----------------- + +.. code-block:: bash + + ssh @login.sherlock.stanford.edu + # Type in Stanford Password + # Do the Duo authentication + # The following setup steps should be done using the Sherlock terminal + # NOTE that this is a LOGIN node, so no major computing should be done here + + # It is best to use a compute node for things like cloning the repo, running code, resetting lpad, etc + + srun -p mcovert --time=4:00:00 --cpus-per-task=2 --pty bash + + # srun is the command for launching a job step under Slurm + # -p or --partition specifies which partition (queue) to use, choose covert :D + # --time: sets the job's wall‑clock time limit + # --cpus-per-task specifies # CPU cores for each task in this job step + # --pty: allocates a pseudo‑terminal (TTY) to run an interactive session + # bash: launching a Bash shell + # When it finished, usually you can see your JOB ID in your shell + + # You can use scancel to abort your job step + scancel + +You can also refer to the Sherlock Documentation: https://www.sherlock.stanford.edu/docs/getting-started/connecting/ + +Clone the vEcoli Repository +---------------------------- + +1. Git clone the vEcoli repo to your Sherlock account: + +.. code-block:: bash + + git clone https://github.com/CovertLab/vEcoli.git + +If you have already created your branch, you can use: + +.. code-block:: bash + + # View all the branches (including remote branch) + git branch -a + + # Checkout to your own branch + git checkout + + # Validate your current branch + git branch + After cloning the model repository to your home directory, add the following lines to your ``~/.bash_profile``, then close and reopen your SSH connection: @@ -65,6 +131,18 @@ a workflow on Sherlock. To run scripts on Sherlock outside a workflow, see :ref:`sherlock-interactive`. To run scripts on Sherlock through a SLURM batch script, see :ref:`sherlock-noninteractive`. +.. tip:: + * You can use ``nano`` as text editor: + + .. code-block:: bash + + nano ~/.bash_profile + # After writing, you can use Ctrl+O to write out, Enter to confirm, and Ctrl+X to exit + + * If you choose to use ``vim``, press ``i`` for insert, and press ``Esc``, then type ``:wq`` and Enter for writing out + * Before running the ``python3`` to set up the env, ensure you are in the vEcoli repo + * It usually takes time to run first job + .. note:: The above setup is sufficient to run workflows on Sherlock. However, if you have a compelling reason to update the shared Nextflow or HyperQueue binaries, @@ -73,6 +151,17 @@ To run scripts on Sherlock through a SLURM batch script, see :ref:`sherlock-noni 1. Nextflow: ``NXF_EDGE=1 nextflow self-update`` 2. HyperQueue: See :ref:`hq-info`. + Then, reset the permissions of the updated binaries with ``chmod 777 *``. + +.. warning:: + + Before building your own config file and running an experiment, remember: + + Python scripts (other than runscripts/workflow.py) **WILL NOT** run on Sherlock directly. + This includes the standalone ParCa, simulation, and analysis run scripts. + Instead, these scripts can be run inside an :ref:`sherlock-interactive` (ideal for script development or debugging) + or :ref:`sherlock-noninteractive` (ideal for longer or more resource-intensive scripts that do not require user input). + .. _sherlock-config: Configuration @@ -104,8 +193,11 @@ keys in your configuration JSON (note the top-level ``sherlock`` key): In addition to these options, you **MUST** set the emitter output directory (see description of ``emitter_arg`` in :ref:`json_config`) to a path with -enough space to store your workflow outputs. We recommend setting this to -a location in your ``$SCRATCH`` directory (e.g. ``/scratch/users/{username}/out``). +enough space to store your workflow outputs. + +.. important:: + We recommend setting ``emitter_arg`` to a location in your ``$SCRATCH`` directory (e.g. ``"out_dir": "/scratch/users/{username}/out"``), + since ``$HOME`` only has a pretty small storage limit (run ``sh_quota`` to view). If using the Parquet emitter and ``threaded`` is not set to false under ``emitter_arg``, a warning will be printed suggesting that you set ``threaded`` @@ -136,8 +228,9 @@ in the path to your config JSON. .. warning:: Remember to use ``python3`` to start workflows instead of ``python``. + This command is supposed to run on **login node**, which means there is no need to use ``srun`` to request a **compute node**. + If there is trouble with permission denied for nextflow (you can use ``nextflow -version`` to check out), you can try ``chmod a+rwx`` -This command should be run on a login node (no need to request a compute node). If ``build_image`` is true in your config JSON, the terminal will report that a SLURM job was submitted to build the container image. When the image build job starts, the terminal will report the build progress. @@ -149,10 +242,19 @@ job starts, the terminal will report the build progress. Do not make any changes to your cloned repository or close your SSH connection until the build has finished. -Once the build has finished, the terminal will report that a SLURM job +Once the build has finished, the terminal will report that a **SLURM job** was submitted for the Nextflow workflow orchestrator before exiting back to the shell. At this point, you are free to close your connection, -start additional workflows, etc. Unlike workflows run locally, Sherlock's +start additional workflows, etc. You can use ``squeue`` to view the status of your SLURM job: + +.. code-block:: bash + + # View by job + squeue -j + # View by user + squeue -u + +Unlike workflows run locally, Sherlock's containerized workflows mean any changes made to the repository after the container image has been built will not affect the running workflow. @@ -235,6 +337,8 @@ More specifically, users who wish to debug a failed workflow job should: Any changes that you make to ``/vEcoli`` inside the container are discarded when the container terminates. +Moreover, if you want to exit the interactive image, just type ``exit`` command. + To start an interactive container that reflects the current state of your cloned repository, navigate to your cloned repository and run the above command with the ``-d`` flag to start a "development" container: @@ -286,11 +390,77 @@ to include one of the following directives at the top of your script: - ``#SBATCH --partition=owners,normal``: Uses either the ``owners`` or ``normal`` partition. This is the recommended option for the vast majority of scripts. +Following is a sample of sbatch scripts for requiring more resources to analysis simulation results: + +.. code-block:: bash + + #!/usr/bin/bash + #SBATCH --job-name=analysis_job + #SBATCH --output=analysis_job.%j.out + #SBATCH --error=analysis_job.%j.err + #SBATCH --time=20:00 + #SBATCH --ntasks=1 + #SBATCH --partition=owners,normal + #SBATCH --cpus-per-task=4 + #SBATCH --mem=64GB + + srun runscripts/container/interactive.sh -i -a -c "python runscripts/analysis.py --config " + +Then, use ``sbatch`` to submit the job: + +.. code-block:: bash + + sbatch .sh + +The ``.err`` and ``.out`` files will be created in the same directory as the sbatch script. + Just as with interactive containers, to run scripts directly from your cloned repository and not the snapshot, add the ``-d`` flag and drop the ``/vEcoli/`` prefix from script names. Note that changing files in your cloned repository may affect SLURM batch jobs submitted with this flag. +.. _Download Results to Local from Sherlock: + +Download Results to Local from Sherlock +==================================== + +It's recommended to turn to +`Sherlock's Data Transfer documentation `_ +for details on transferring files to and from your local machine. + +Following are common methods ``scp`` and ``rsync``: + +``scp`` is convenient for downloading files from the cluster. You can simply execute the following on your **local terminal**: + +.. code-block:: bash + + # -r for recursively duplicate the whole repo: + scp -r @login.sherlock.stanford.edu:/path/to/remote/folder /path/to/local/destination + + # If you only want to download single file: + scp @login.sherlock.stanford.edu:/path/to/remote/file /path/to/local/destination/ + +In practice, usually we want to get the analytical results for our simulation. +Due to the report files being HTML files typically, we can turn to shell wildcard and use ``rsync`` with ``include/exclude`` filters: + +.. code-block:: bash + + # Recursively downloads all .html files under the specific directory on Sherlock + # to your local machine while preserving the subdirectory structure: + + rsync -av --prune-empty-dirs \ + --include='*/' --include='*.html' --exclude='*' \ + @login.sherlock.stanford.edu:/path/to/remote/folder /path/to/local/destination + + # --include='*/': Keeps all directories, allowing rsync to traverse into subdirectories + # --include='*.html': Includes only .html files + # --exclude='*': Excludes everything else + # -a: Archive mode (preserves metadata) + # -v: Verbose output + # --prune-empty-dirs: Avoids creating empty directories on the local machine + +Both ``scp`` and ``rsync`` will require your password and Duo validation. + .. _other-cluster: -------------- diff --git a/ecoli/analysis/multigeneration/catalyst_count.py b/ecoli/analysis/multigeneration/catalyst_count.py new file mode 100644 index 000000000..470d1f766 --- /dev/null +++ b/ecoli/analysis/multigeneration/catalyst_count.py @@ -0,0 +1,27 @@ +""" +Visualize catalyst counts over time for specified BioCyc reactions across generations. +For each specific BioCyc ID reaction, this scripts will add all the catalysts which catalyse it: +```number of catalysts = sum(number of catalysts[i])``` + +Supports two visualization modes: +1. 'grid' mode: Each row represents a variant, each column represents a reaction's catalysts +2. 'stacked' mode: Each reaction's catalysts get their own chart, variants shown as different colored lines + +You can specify the reactions and layout using parameters: + "catalyst_count": { + # Required: specify BioCyc reaction IDs to visualize + "BioCyc_ID": ["Name1", "Name2", ...], + # Optional: specify generations to visualize + # If not specified, all generations will be used + "generation": [1, 2, ...], + # Optional: specify layout mode ('grid' or 'stacked') + # Default: 'stacked' + "layout": "stacked" # or "grid" + } + +This script is the dummy version of ecoli.analysis.multivariant.catalyst_count, you can turn to origin file for more detail +""" + +from ecoli.analysis.multivariant.catalyst_count import plot + +__all__ = ["plot"] diff --git a/ecoli/analysis/multigeneration/fba_flux.py b/ecoli/analysis/multigeneration/fba_flux.py new file mode 100644 index 000000000..8496f062c --- /dev/null +++ b/ecoli/analysis/multigeneration/fba_flux.py @@ -0,0 +1,22 @@ +""" +Visualize FBA reaction fluxes over time for specified reactions with net flux calculation across multiple variants. + +Supports two visualization modes: +1. 'grid' mode: Each row represents a variant, each column represents a reaction +2. 'stacked' mode: Each reaction gets its own chart, variants shown as different colored lines + +You can specify the reactions and layout using parameters: + "fba_flux": { + # Required: specify BioCyc reaction IDs to visualize + "BioCyc_ID": ["Name1", "Name2", ...], + # Optional: specify layout mode ('grid' or 'stacked') + # Default: 'stacked' + "layout": "stacked" # or "grid" + } + +This script is the dummy version of ecoli.analysis.multivariant.fba_flux, you can turn to origin file for more detail +""" + +from ecoli.analysis.multivariant.fba_flux import plot + +__all__ = ["plot"] diff --git a/ecoli/analysis/multigeneration/fba_flux_heatmap.py b/ecoli/analysis/multigeneration/fba_flux_heatmap.py new file mode 100644 index 000000000..e15a2a1c1 --- /dev/null +++ b/ecoli/analysis/multigeneration/fba_flux_heatmap.py @@ -0,0 +1,342 @@ +""" +Visualize FBA reaction fluxes as a heatmap across multiple generations. + +You can specify the reactions to visualize using parameters in params: + "fba_flux_heatmap": { + "BioCyc_ID": ["Name1", "Name2", ...], # Required: reactions to analyze + "normalized_reaction": "NormalizationReactionName", # Optional: reaction for normalization + "generation": [1, 2, ...] # Optional: generations to analyze (default: all) + } + +This script will: +1. Find all reactions matching the specified BioCyc IDs using efficient SQL-based approach +2. Calculate net flux for each reaction (forward - reverse) directly in SQL +3. Calculate time-averaged net flux for each generation +4. Optionally normalize fluxes relative to a reference reaction (flux/reference_flux * 100) +5. Create a heatmap with generations on y-axis and reactions on x-axis + +Normalization formula: normalized_flux = (reaction_flux / reference_reaction_flux) * 100 +If normalized_reaction is not specified, raw flux values are used. +""" + +import altair as alt +import os +from typing import Any + +import polars as pl +from duckdb import DuckDBPyConnection +import pandas as pd + +from ecoli.library.parquet_emitter import field_metadata +from ecoli.analysis.utils import ( + create_base_to_extended_mapping, + build_flux_calculation_sql, +) + + +def plot( + params: dict[str, Any], + conn: DuckDBPyConnection, + history_sql: str, + config_sql: str, + success_sql: str, + sim_data_dict: dict[str, dict[int, str]], + validation_data_paths: list[str], + outdir: str, + variant_metadata: dict[str, dict[int, Any]], + variant_names: dict[str, str], +): + """Create a heatmap visualization of FBA reaction net fluxes across multiple generations.""" + + # Get parameters + biocyc_ids = params.get("BioCyc_ID", []) + normalized_reaction = params.get("normalized_reaction", None) + target_generations = params.get("generation", None) + + if not biocyc_ids: + print( + "[ERROR] No BioCyc_ID found in params. Please specify reaction IDs to visualize." + ) + return None + + if isinstance(biocyc_ids, str): + biocyc_ids = [biocyc_ids] + + print(f"[INFO] Creating heatmap for {len(biocyc_ids)} reactions: {biocyc_ids}") + if normalized_reaction: + print(f"[INFO] Normalizing relative to: {normalized_reaction}") + + # All reactions we need to analyze (including normalization reaction if specified) + all_reactions = biocyc_ids.copy() + if normalized_reaction and normalized_reaction not in all_reactions: + all_reactions.append(normalized_reaction) + + # Create base to extended reaction mapping + base_to_extended_mapping = create_base_to_extended_mapping(sim_data_dict) + if not base_to_extended_mapping: + print("[ERROR] Could not create base to extended reaction mapping") + return None + + # Load reaction IDs from config + try: + all_reaction_ids = field_metadata( + conn, config_sql, "listeners__fba_results__reaction_fluxes" + ) + print(f"[INFO] Total reactions in sim_data: {len(all_reaction_ids)}") + except Exception as e: + print(f"[ERROR] Error loading reaction IDs: {e}") + return None + + # Build SQL query for efficient flux calculation + flux_calculation_sql, valid_reactions = build_flux_calculation_sql( + all_reactions, base_to_extended_mapping, all_reaction_ids, history_sql + ) + + if not flux_calculation_sql or not valid_reactions: + print("[ERROR] Could not build flux calculation SQL") + return None + + print(f"[INFO] Processing {len(valid_reactions)} valid reactions") + + # Check if normalization reaction is valid + if normalized_reaction and normalized_reaction not in valid_reactions: + print( + f"[WARNING] Normalization reaction {normalized_reaction} not found. Proceeding without normalization." + ) + normalized_reaction = None + + # Filter valid BioCyc IDs for visualization + # Exclude normalization reaction if it's not in the original list + valid_biocyc_ids = [rxn for rxn in biocyc_ids if rxn in valid_reactions] + + if not valid_biocyc_ids: + print("[ERROR] No valid BioCyc IDs found for visualization") + return None + + print(f"[INFO] Visualizing {len(valid_biocyc_ids)} reactions in heatmap") + + # Execute the optimized SQL query + try: + df = conn.sql(flux_calculation_sql).pl() + print(f"[INFO] Loaded flux data with {df.height} time steps") + except Exception as e: + print(f"[ERROR] Error executing flux calculation SQL: {e}") + return None + + if df.height == 0: + print("[ERROR] No data found") + return None + + # Filter by specified generations if provided + if target_generations is not None: + print(f"[INFO] Target generations: {target_generations}") + df = df.filter(pl.col("generation").is_in(target_generations)) + + if df.height == 0: + print("[ERROR] No data found for specified generations") + return None + + # Print generation information + unique_generations = sorted(df["generation"].unique().to_list()) + print(f"[INFO] Found {len(unique_generations)} generations: {unique_generations}") + + # Calculate time-averaged net flux for each generation and reaction + print("[INFO] Calculating time-averaged fluxes for each generation...") + + heatmap_data = [] + + for generation in unique_generations: + generation_data = df.filter(pl.col("generation") == generation) + + generation_averages = {} + for reaction_id in valid_reactions: + net_flux_col = f"{reaction_id}_net_flux" + if net_flux_col in generation_data.columns: + avg_flux = generation_data[net_flux_col].mean() + generation_averages[reaction_id] = avg_flux + else: + generation_averages[reaction_id] = 0.0 + + # Apply normalization if specified + if normalized_reaction and normalized_reaction in generation_averages: + norm_flux = generation_averages[normalized_reaction] + if abs(norm_flux) > 1e-10: # Avoid division by zero + print( + f"[INFO] Gen {generation}: Normalizing by {normalized_reaction} = {norm_flux:.6f}" + ) + for reaction_id in valid_biocyc_ids: + if reaction_id in generation_averages: + normalized_value = ( + generation_averages[reaction_id] / norm_flux + ) * 100 + heatmap_data.append( + { + "Generation": str(generation), + "Reaction": reaction_id, + "Net_Flux": normalized_value, + "Raw_Flux": generation_averages[reaction_id], + "Normalization_Flux": norm_flux, + } + ) + else: + print( + f"[WARNING] Gen {generation}: Normalization reaction flux is zero, using raw values" + ) + for reaction_id in valid_biocyc_ids: + if reaction_id in generation_averages: + heatmap_data.append( + { + "Generation": str(generation), + "Reaction": reaction_id, + "Net_Flux": generation_averages[reaction_id], + "Raw_Flux": generation_averages[reaction_id], + "Normalization_Flux": 0.0, + } + ) + else: + # No normalization + for reaction_id in valid_biocyc_ids: + if reaction_id in generation_averages: + heatmap_data.append( + { + "Generation": str(generation), + "Reaction": reaction_id, + "Net_Flux": generation_averages[reaction_id], + "Raw_Flux": generation_averages[reaction_id], + "Normalization_Flux": None, + } + ) + + if not heatmap_data: + print("[ERROR] No heatmap data could be generated") + return None + + heatmap_df = pd.DataFrame(heatmap_data) + print(f"[INFO] Generated heatmap data with {len(heatmap_data)} data points") + + # Print some statistics about the flux data + flux_stats = { + "min": heatmap_df["Net_Flux"].min(), + "max": heatmap_df["Net_Flux"].max(), + "mean": heatmap_df["Net_Flux"].mean(), + "std": heatmap_df["Net_Flux"].std(), + } + print( + f"[INFO] Flux statistics: min={flux_stats['min']:.6f}, max={flux_stats['max']:.6f}, " + f"mean={flux_stats['mean']:.6f}, std={flux_stats['std']:.6f}" + ) + + # Create the heatmap using Altair + print("[INFO] Creating heatmap visualization...") + + # Determine color scale based on data range + flux_min = heatmap_df["Net_Flux"].min() + flux_max = heatmap_df["Net_Flux"].max() + flux_abs_max = max(abs(flux_min), abs(flux_max)) + + # Use diverging color scheme if data crosses zero, sequential otherwise + if flux_min < 0 and flux_max > 0: + color_scale = alt.Scale( + scheme="redblue", domain=[-flux_abs_max, flux_abs_max], type="linear" + ) + print("[INFO] Using diverging color scheme (data crosses zero)") + else: + color_scale = alt.Scale(scheme="viridis", type="linear") + print("[INFO] Using sequential color scheme") + + # Create heatmap + flux_title = ( + "Net Flux (% of norm)" if normalized_reaction else "Net Flux (mmol/gDW/hr)" + ) + chart_title = f"FBA Net Flux Heatmap: {len(unique_generations)} Generations x {len(valid_biocyc_ids)} Reactions" + if normalized_reaction: + chart_title += f" (Normalized by {normalized_reaction})" + + # Build tooltip list conditionally + tooltip_list = [ + "Generation:N", + "Reaction:N", + alt.Tooltip("Net_Flux:Q", format=".4f", title="Net Flux"), + alt.Tooltip("Raw_Flux:Q", format=".6f", title="Raw Flux"), + ] + if normalized_reaction: + tooltip_list.append( + alt.Tooltip("Normalization_Flux:Q", format=".6f", title="Norm Flux") + ) + + # Calculate appropriate chart dimensions + chart_width = max(400, len(valid_biocyc_ids) * 60) + chart_height = max(300, len(unique_generations) * 50) + + heatmap = ( + alt.Chart(heatmap_df) + .mark_rect(stroke="white", strokeWidth=1) + .encode( + x=alt.X( + "Reaction:N", + title="BioCyc Reaction ID", + sort=valid_biocyc_ids, + axis=alt.Axis(labelAngle=-45), + ), + y=alt.Y( + "Generation:N", + title="Generation", + sort=alt.SortField(field="Generation", order="ascending"), + ), + color=alt.Color("Net_Flux:Q", title=flux_title, scale=color_scale), + tooltip=tooltip_list, + ) + .properties( + width=chart_width, + height=chart_height, + title=alt.TitleParams(text=chart_title, fontSize=14, anchor="start"), + ) + ) + + # Add text annotations on the heatmap cells (only if not too many cells) + total_cells = len(valid_biocyc_ids) * len(unique_generations) + if total_cells <= 100: # Only add text for smaller heatmaps to avoid clutter + text_annotations = ( + alt.Chart(heatmap_df) + .mark_text(baseline="middle", fontSize=10, fontWeight="bold") + .encode( + x=alt.X("Reaction:N", sort=valid_biocyc_ids), + y=alt.Y( + "Generation:N", + sort=alt.SortField(field="Generation", order="ascending"), + ), + text=alt.Text("Net_Flux:Q", format=".2f"), + color=alt.condition( + f"datum.Net_Flux > {flux_abs_max * 0.5}", + alt.value("white"), + alt.value("black"), + ), + ) + ) + # Combine heatmap and text + final_chart = heatmap + text_annotations + print("[INFO] Added text annotations to heatmap") + else: + final_chart = heatmap + print("[INFO] Skipped text annotations (too many cells for readability)") + + # Save the plot + output_path = os.path.join(outdir, "fba_flux_heatmap.html") + final_chart.save(output_path) + print(f"[INFO] Saved heatmap visualization to: {output_path}") + + # Save the underlying data as CSV for reference + csv_path = os.path.join(outdir, "fba_flux_heatmap_data.csv") + heatmap_df.to_csv(csv_path, index=False) + print(f"[INFO] Saved heatmap data to: {csv_path}") + + # Print summary of results + print("[INFO] Heatmap visualization completed successfully!") + print(f"[INFO] - Generations analyzed: {len(unique_generations)}") + print(f"[INFO] - Reactions visualized: {len(valid_biocyc_ids)}") + print( + f"[INFO] - Normalization: {'Yes (' + normalized_reaction + ')' if normalized_reaction else 'No'}" + ) + print(f"[INFO] - Total data points: {len(heatmap_data)}") + + return final_chart diff --git a/ecoli/analysis/multigeneration/fba_flux_oscillating.py b/ecoli/analysis/multigeneration/fba_flux_oscillating.py new file mode 100644 index 000000000..794b591f2 --- /dev/null +++ b/ecoli/analysis/multigeneration/fba_flux_oscillating.py @@ -0,0 +1,834 @@ +""" +This script preprocesses FBA flux data by mapping extended reactions to base reactions, +computes net fluxes for base reactions (forward extended - reverse extended), +identifies oscillating base reactions (those that take both positive and negative values), +computes dynamic metrics for oscillating reactions (positive/negative ratios, oscillation times/frequency), +and creates a semi-log heat scatter plot (with marginal density curves) focused on oscillating reactions. + +Modified to work with DuckDB connection and SQL queries instead of direct file loading. +All outputs are saved to outdir. + +Key changes from the original: +1. Removed unified directional combined heat/scatter plotting (no Always Positive / Always Negative / Always Zero plots). +2. Removed per-category statistics for always positive / always negative / always zero; focus is only on oscillating reactions. +3. Added a semi-log heat scatter plot for oscillating reactions: + - x axis: oscillation_frequency = oscillation_times / total_timepoints + where oscillation_times is the number of sign changes in the sequence after removing near-zero points. + - y axis: log10(positive_ratio / negative_ratio), where ratios are computed using epsilon thresholds. + - Both axes have marginal density curves (KDE). + - All comparisons to zero use eps: positive if > eps, negative if < -eps, zero if abs <= eps. + +eps can be set use params. +""" + +import os +from typing import Any, Dict, Tuple + +import numpy as np +import pandas as pd +from duckdb import DuckDBPyConnection +from scipy.stats import gaussian_kde +import plotly.graph_objects as go +from plotly.subplots import make_subplots + +from ecoli.library.parquet_emitter import field_metadata +from ecoli.analysis.utils import create_base_to_extended_mapping + + +def load_fba_data( + conn: DuckDBPyConnection, history_sql: str, config_sql: str, sim_data_dict: dict +) -> Tuple[pd.DataFrame, dict]: + """ + Load FBA flux data using DuckDB connection and SQL queries. + + Returns: + - df: DataFrame with flux data (time points x extended reactions, with 'time' column included) + - metadata: Dictionary with experiment metadata + """ + print("[INFO] Loading FBA flux data via SQL...") + + try: + base_to_extended_mapping = create_base_to_extended_mapping(sim_data_dict) + if not base_to_extended_mapping: + raise Exception("Could not create base to extended reaction mapping") + + reaction_ids = field_metadata( + conn, config_sql, "listeners__fba_results__reaction_fluxes" + ) + print(f"[INFO] Total reactions in sim_data: {len(reaction_ids)}") + + required_columns = [ + "time", + "generation", + "listeners__fba_results__reaction_fluxes", + ] + + sql = f""" + SELECT {", ".join(required_columns)} + FROM ({history_sql}) + ORDER BY generation, time + """ + + df_pl = conn.sql(sql).pl() + + if df_pl.is_empty(): + raise Exception("No data found") + print(f"[INFO] Loaded data with {df_pl.height} time steps") + + flux_matrix = df_pl["listeners__fba_results__reaction_fluxes"].to_numpy() + flux_matrix = np.array([np.array(row) for row in flux_matrix]) + + df = pd.DataFrame(flux_matrix, columns=reaction_ids) + + time_data = df_pl.select(["time"]).to_pandas() + df = pd.concat([time_data, df], axis=1) + + # Drop initial time point where time == 0 + df = df[df["time"] != 0].reset_index(drop=True) + + metadata = { + "n_extended_reactions": len(reaction_ids), + "n_timepoints": len(df), + "extended_reaction_names": reaction_ids, + "base_to_extended_mapping": base_to_extended_mapping, + } + + print("[INFO] Successfully loaded data:") + print(f" - Time points: {len(df)}") + print(f" - Extended reactions: {len(reaction_ids)}") + print(f" - Base reaction mapping entries: {len(base_to_extended_mapping)}") + + return df, metadata + + except Exception as e: + print(f"[ERROR] Failed to load data: {str(e)}") + raise + + +def map_extended_to_base_reactions(extended_reactions, base_to_extended_mapping): + """ + Map extended reactions to base reactions and identify forward/reverse relationships. + Returns: + - base_reaction_mapping: Dict with base reaction info including forward/reverse extended reactions + - extended_to_base_map: Dict mapping each extended reaction to its base reaction + """ + # Create reverse mapping from extended to base + extended_to_base_map = {} + for base_rxn, extended_list in base_to_extended_mapping.items(): + for extended_rxn in extended_list: + extended_to_base_map[extended_rxn] = base_rxn + + base_reaction_mapping = {} + + for extended_reaction in extended_reactions: + base_reaction = extended_to_base_map.get(extended_reaction) + + if base_reaction is None: + print( + f"[WARNING] No base reaction found for extended reaction: {extended_reaction}" + ) + continue + + if base_reaction not in base_reaction_mapping: + base_reaction_mapping[base_reaction] = { + "forward_extended": [], + "reverse_extended": [], + "all_extended": [], + } + + if extended_reaction.endswith(" (reverse)"): + base_reaction_mapping[base_reaction]["reverse_extended"].append( + extended_reaction + ) + else: + base_reaction_mapping[base_reaction]["forward_extended"].append( + extended_reaction + ) + + base_reaction_mapping[base_reaction]["all_extended"].append(extended_reaction) + + print("[INFO] Base reaction mapping results:") + print(f" - Total base reactions: {len(base_reaction_mapping)}") + print(f" - Extended reactions mapped: {len(extended_to_base_map)}") + + # Print concise distribution summary (only informative) + forward_only = sum( + 1 + for info in base_reaction_mapping.values() + if len(info["forward_extended"]) > 0 and len(info["reverse_extended"]) == 0 + ) + reverse_only = sum( + 1 + for info in base_reaction_mapping.values() + if len(info["forward_extended"]) == 0 and len(info["reverse_extended"]) > 0 + ) + both_directions = sum( + 1 + for info in base_reaction_mapping.values() + if len(info["forward_extended"]) > 0 and len(info["reverse_extended"]) > 0 + ) + + print(f" - Base reactions with forward extended only: {forward_only}") + print(f" - Base reactions with reverse extended only: {reverse_only}") + print( + f" - Base reactions with both forward and reverse extended: {both_directions}" + ) + + return base_reaction_mapping, extended_to_base_map + + +def compute_base_reaction_fluxes(flux_df: pd.DataFrame, base_reaction_mapping: dict): + """ + Compute base reaction net fluxes (forward - reverse). + Returns: + - base_flux_df: DataFrame with base reaction net flux values (columns = base reactions) + - base_reaction_details: dict with details per base reaction + """ + base_flux_data = {} + base_reaction_details = {} + + for base_reaction, info in base_reaction_mapping.items(): + forward_extended = info["forward_extended"] + reverse_extended = info["reverse_extended"] + + forward_flux = pd.Series(0.0, index=flux_df.index) + if forward_extended: + for ext_reaction in forward_extended: + if ext_reaction in flux_df.columns: + forward_flux = forward_flux.add( + flux_df[ext_reaction], fill_value=0.0 + ) + + reverse_flux = pd.Series(0.0, index=flux_df.index) + if reverse_extended: + for ext_reaction in reverse_extended: + if ext_reaction in flux_df.columns: + reverse_flux = reverse_flux.add( + flux_df[ext_reaction], fill_value=0.0 + ) + + net_flux = forward_flux - reverse_flux + base_flux_data[base_reaction] = net_flux + + base_reaction_details[base_reaction] = { + "forward_extended": forward_extended, + "reverse_extended": reverse_extended, + "n_forward_extended": len(forward_extended), + "n_reverse_extended": len(reverse_extended), + "total_extended": len(info["all_extended"]), + } + + base_flux_df = pd.DataFrame(base_flux_data) + + print("[INFO] Base reaction flux computation results:") + print(f" - Base reactions computed: {len(base_flux_df.columns)}") + print(f" - Time points: {len(base_flux_df)}") + + return base_flux_df, base_reaction_details + + +def categorize_base_reactions_by_flux_behavior( + base_flux_df: pd.DataFrame, eps: float = 1e-30 +): + """ + Categorize base reactions by behavior using epsilon threshold. + Only returns oscillating list in informative print; full categories still returned (for possible downstream needs). + """ + always_positive = [] + always_negative = [] + oscillating = [] + always_zero = [] + base_reaction_categories = {} + + for base_reaction in base_flux_df.columns: + flux_values = base_flux_df[base_reaction].values + + has_positive = np.any(flux_values > eps) + has_negative = np.any(flux_values < -eps) + has_zero = np.any(np.abs(flux_values) <= eps) + + min_flux = flux_values.min() + max_flux = flux_values.max() + max_abs_flux = np.max(np.abs(flux_values)) + + if max_abs_flux <= eps: + always_zero.append(base_reaction) + category = "always_zero" + elif not has_negative and has_positive: + always_positive.append(base_reaction) + category = "always_positive" + elif not has_positive and has_negative: + always_negative.append(base_reaction) + category = "always_negative" + elif has_positive and has_negative: + oscillating.append(base_reaction) + category = "oscillating" + else: + always_zero.append(base_reaction) + category = "always_zero" + + base_reaction_categories[base_reaction] = { + "category": category, + "min_flux": min_flux, + "max_flux": max_flux, + "max_abs_flux": max_abs_flux, + "has_positive": has_positive, + "has_negative": has_negative, + "has_zero": has_zero, + } + + print("\n[INFO] Base reaction categorization by flux behavior:") + print(f" - Oscillating (changes sign): {len(oscillating)}") + + return ( + always_positive, + always_negative, + oscillating, + always_zero, + base_reaction_categories, + ) + + +def compute_oscillating_dynamics_metrics( + base_flux_df: pd.DataFrame, oscillating_reactions: list, eps: float = 1e-30 +) -> pd.DataFrame: + """ + Compute dynamic metrics for oscillating reactions: + - positive_ratio, negative_ratio, zero_ratio (all relative to total timepoints) + - oscillation_times: number of sign changes in the non-zero filtered sequence + - oscillation_frequency = oscillation_times / total_timepoints + - log_pos_neg_ratio = log10(positive_ratio / negative_ratio), handled safely with eps + """ + oscillating_data = [] + + for base_reaction in oscillating_reactions: + flux_values = base_flux_df[base_reaction].values + n_timepoints = len(flux_values) + + positive_count = np.sum(flux_values > eps) + negative_count = np.sum(flux_values < -eps) + zero_count = np.sum(np.abs(flux_values) <= eps) + + positive_ratio = positive_count / n_timepoints + negative_ratio = negative_count / n_timepoints + zero_ratio = zero_count / n_timepoints + + # Remove near-zero values for oscillation counting + non_zero_mask = np.abs(flux_values) > eps + non_zero_flux = flux_values[non_zero_mask] + + if len(non_zero_flux) <= 1: + oscillation_times = 0 + oscillation_frequency = 0.0 + else: + # Count sign changes in consecutive non-zero sequence + signs = np.sign(non_zero_flux) + sign_changes = np.sum(signs[1:] != signs[:-1]) + oscillation_times = int(sign_changes) + oscillation_frequency = oscillation_times / n_timepoints + + # Compute log ratio safely + if negative_ratio > 0: + log_pos_neg_ratio = np.log10(positive_ratio / negative_ratio) + elif positive_ratio > 0: + log_pos_neg_ratio = np.log10(positive_ratio / eps) + else: + # Both zero (shouldn't be oscillating) -> fallback 0 + log_pos_neg_ratio = 0.0 + + oscillating_data.append( + { + "base_reaction": base_reaction, + "positive_ratio": positive_ratio, + "negative_ratio": negative_ratio, + "zero_ratio": zero_ratio, + "positive_count": int(positive_count), + "negative_count": int(negative_count), + "zero_count": int(zero_count), + "oscillation_times": oscillation_times, + "oscillation_frequency": oscillation_frequency, + "log_pos_neg_ratio": log_pos_neg_ratio, + "n_timepoints": n_timepoints, + "non_zero_points": int(len(non_zero_flux)), + "min_flux": float(np.min(flux_values)), + "max_flux": float(np.max(flux_values)), + "max_abs_flux": float(np.max(np.abs(flux_values))), + } + ) + + oscillating_metrics = pd.DataFrame(oscillating_data) + + if len(oscillating_metrics) > 0: + print( + f"\n[INFO] Oscillating dynamics metrics computed for {len(oscillating_metrics)} reactions:" + ) + print( + f" - Oscillation frequency range: {oscillating_metrics['oscillation_frequency'].min():.4f} to {oscillating_metrics['oscillation_frequency'].max():.4f}" + ) + print( + f" - Log(pos/neg ratio) range: {oscillating_metrics['log_pos_neg_ratio'].min():.4f} to {oscillating_metrics['log_pos_neg_ratio'].max():.4f}" + ) + print( + f" - Positive ratio range: {oscillating_metrics['positive_ratio'].min():.4f} to {oscillating_metrics['positive_ratio'].max():.4f}" + ) + print( + f" - Negative ratio range: {oscillating_metrics['negative_ratio'].min():.4f} to {oscillating_metrics['negative_ratio'].max():.4f}" + ) + else: + print("\n[WARN] No oscillating metrics computed (empty list).") + + return oscillating_metrics + + +def create_oscillating_dynamics_plot( + oscillating_metrics: pd.DataFrame, + base_reaction_details: Dict[str, dict], + outdir: str, + eps: float = 1e-30, +): + """ + Create a semi-log heat scatter plot for oscillating reactions with marginal density curves. + - X-axis: oscillation_frequency + - Y-axis: log10(positive_ratio / negative_ratio) + Saves HTML file to outdir. + """ + if oscillating_metrics is None or len(oscillating_metrics) == 0: + print("[WARNING] No oscillating reactions found for dynamics plot.") + return + + # Prepare data + x = oscillating_metrics["oscillation_frequency"] + y = oscillating_metrics["log_pos_neg_ratio"] + + # Compute density for scatter coloring + if len(oscillating_metrics) > 1: + xy = np.vstack([x, y]) + try: + density = gaussian_kde(xy)(xy) + except Exception: + # If KDE fails due to singular matrix etc., fallback to 1D product KDE + density = np.ones(len(x)) + else: + density = np.array([1.0]) + + # Hover text + hover_text = [] + for idx, row in oscillating_metrics.iterrows(): + base_reaction = row["base_reaction"] + details = base_reaction_details.get(base_reaction, {}) + forward_ext = details.get("forward_extended", []) + reverse_ext = details.get("reverse_extended", []) + + ext_info = f"Forward: {len(forward_ext)} extended, Reverse: {len(reverse_ext)} extended" + if len(forward_ext) <= 3: + ext_info += ( + f"
Forward: {', '.join(forward_ext) if forward_ext else 'None'}" + ) + if len(reverse_ext) <= 3: + ext_info += ( + f"
Reverse: {', '.join(reverse_ext) if reverse_ext else 'None'}" + ) + + hover_text.append( + f"Base Reaction: {base_reaction}
" + + f"Extended Reactions: {ext_info}
" + + f"Oscillation Frequency: {row['oscillation_frequency']:.6f}
" + + f"Oscillation Times: {row['oscillation_times']}
" + + f"Log(Pos/Neg Ratio): {row['log_pos_neg_ratio']:.4f}
" + + f"Positive Ratio: {row['positive_ratio']:.4f}
" + + f"Negative Ratio: {row['negative_ratio']:.4f}
" + + f"Zero Ratio: {row['zero_ratio']:.4f}
" + + f"Max |Net Flux|: {row['max_abs_flux']:.2e}
" + + f"Min Net Flux: {row['min_flux']:.2e}
" + + f"Max Net Flux: {row['max_flux']:.2e}
" + + f"Non-zero Points: {row['non_zero_points']}/{row['n_timepoints']}
" + + f"Point Density: {density[idx]:.6e}" + ) + + # Build subplots: top density (x), right density (y), main scatter bottom-left + fig = make_subplots( + rows=2, + cols=2, + column_widths=[0.9, 0.1], + row_heights=[0.1, 0.9], + specs=[ + [{"secondary_y": False}, {"secondary_y": False}], + [{"secondary_y": False}, {"secondary_y": False}], + ], + vertical_spacing=0.05, + horizontal_spacing=0.05, + ) + + # Main scatter (row=2, col=1) + fig.add_trace( + go.Scatter( + x=x, + y=y, + mode="markers", + marker=dict( + size=8, + color=density, + colorscale="Viridis", + opacity=0.85, + colorbar=dict( + title=dict(text="Point Density", font=dict(size=12)), + tickfont=dict(size=11), + thickness=15, + len=0.7, + x=1.02, + ), + line=dict(width=0.5, color="white"), + ), + text=hover_text, + hovertemplate="%{text}", + name="Oscillating Reactions", + showlegend=False, + ), + row=2, + col=1, + ) + + # Top density (x distribution) + if len(x) > 1: + x_range = np.linspace(x.min(), x.max(), 250) + try: + x_density = gaussian_kde(x) + x_density_values = x_density(x_range) + except Exception: + x_density_values = np.ones_like(x_range) + else: + x_range = np.array([x.iloc[0]]) + x_density_values = np.array([1.0]) + + fig.add_trace( + go.Scatter( + x=x_range, + y=x_density_values, + mode="lines", + line=dict(color="darkorange", width=3), + fill="tozeroy", + fillcolor="rgba(255, 140, 0, 0.25)", + showlegend=False, + ), + row=1, + col=1, + ) + + # Right density (y distribution) + if len(y) > 1: + y_range = np.linspace(y.min(), y.max(), 250) + try: + y_density = gaussian_kde(y) + y_density_values = y_density(y_range) + except Exception: + y_density_values = np.ones_like(y_range) + else: + y_range = np.array([y.iloc[0]]) + y_density_values = np.array([1.0]) + + fig.add_trace( + go.Scatter( + x=y_density_values, + y=y_range, + mode="lines", + line=dict(color="darkgreen", width=3), + fill="tozerox", + fillcolor="rgba(0, 100, 0, 0.25)", + showlegend=False, + ), + row=2, + col=2, + ) + + # Layout and axes + fig.update_layout( + title=dict( + text="Oscillating Base Reactions: Frequency vs Positive/Negative Bias (semi-log)", + font=dict(size=18), + x=0.5, + ), + plot_bgcolor="white", + paper_bgcolor="white", + font=dict(family="Arial", size=12), + width=1000, + height=800, + margin=dict(l=80, r=120, t=120, b=80), + ) + + fig.update_xaxes( + title=dict( + text="Oscillation Frequency (oscillations per time step)", + font=dict(size=14), + ), + row=2, + col=1, + gridcolor="rgba(128,128,128,0.2)", + gridwidth=1, + ) + fig.update_yaxes( + title=dict(text="log₁₀(Positive Ratio / Negative Ratio)", font=dict(size=14)), + row=2, + col=1, + gridcolor="rgba(128,128,128,0.2)", + gridwidth=1, + ) + + # Add hline at y=0 (equal pos/neg) + fig.add_hline( + y=0, + line=dict(color="red", width=1, dash="dash"), + annotation=dict(text="Equal Pos/Neg Bias", font=dict(size=10, color="red")), + row=2, + col=1, + ) + + # Hide top-right subplot (unused) + fig.update_xaxes(visible=False, row=1, col=2) + fig.update_yaxes(visible=False, row=1, col=2) + + # Remove ticks for density subplots + fig.update_xaxes(showticklabels=False, row=1, col=1) + fig.update_yaxes(showticklabels=False, row=1, col=1) + fig.update_xaxes(showticklabels=False, row=2, col=2) + fig.update_yaxes(showticklabels=False, row=2, col=2) + + # Stats annotation + stats_text = ( + f"Oscillating Base Reactions: {len(oscillating_metrics):,}
" + + f"ε = {eps:.0e}
" + + f"Oscillation Frequency Range: {oscillating_metrics['oscillation_frequency'].min():.4f} to {oscillating_metrics['oscillation_frequency'].max():.4f}
" + + f"Log(Pos/Neg Ratio) Range: {oscillating_metrics['log_pos_neg_ratio'].min():.4f} to {oscillating_metrics['log_pos_neg_ratio'].max():.4f}
" + + f"Oscillation Times Range: {oscillating_metrics['oscillation_times'].min()} to {oscillating_metrics['oscillation_times'].max()}
" + + f"Max |Net Flux| Range: {oscillating_metrics['max_abs_flux'].min():.2e} to {oscillating_metrics['max_abs_flux'].max():.2e}" + ) + if len(oscillating_metrics) > 1: + stats_text += f"
Density Range: {density.min():.2e} - {density.max():.2e}" + + fig.add_annotation( + x=0.42, + y=0.78, + xref="paper", + yref="paper", + text=stats_text, + showarrow=False, + font=dict(size=11, color="black"), + bgcolor="rgba(255,255,255,0.95)", + bordercolor="rgba(128,128,128,0.5)", + borderwidth=1, + borderpad=8, + xanchor="left", + yanchor="top", + ) + + # Save plot + filename = os.path.join(outdir, "oscillating_reactions_dynamics_analysis.html") + os.makedirs(outdir, exist_ok=True) + fig.write_html(filename) + print("\n[INFO] Oscillating reactions dynamics plot saved:") + print(f" - {filename}") + + +def print_oscillating_reaction_summaries(oscillating_metrics: pd.DataFrame): + """Print top summaries for oscillating reactions.""" + if oscillating_metrics is None or len(oscillating_metrics) == 0: + print("[INFO] No oscillating reactions to summarize.") + return + + print( + "\n[INFO] Oscillating Base Reactions - Top 10 most dynamic (highest oscillation frequency):" + ) + top_dynamic = oscillating_metrics.sort_values( + "oscillation_frequency", ascending=False + ).head(10) + for _, row in top_dynamic.iterrows(): + print( + f" {row['base_reaction']}: freq={row['oscillation_frequency']:.6f}, " + f"times={row['oscillation_times']}, pos_ratio={row['positive_ratio']:.3f}, " + f"neg_ratio={row['negative_ratio']:.3f}, log_ratio={row['log_pos_neg_ratio']:.3f}" + ) + + print( + "\n[INFO] Oscillating Base Reactions - Top 10 most positive-biased (highest log pos/neg ratio):" + ) + top_positive_biased = oscillating_metrics.sort_values( + "log_pos_neg_ratio", ascending=False + ).head(10) + for _, row in top_positive_biased.iterrows(): + print( + f" {row['base_reaction']}: log_ratio={row['log_pos_neg_ratio']:.3f}, " + f"pos_ratio={row['positive_ratio']:.3f}, neg_ratio={row['negative_ratio']:.3f}, " + f"freq={row['oscillation_frequency']:.6f}" + ) + + print( + "\n[INFO] Oscillating Base Reactions - Top 10 most negative-biased (lowest log pos/neg ratio):" + ) + top_negative_biased = oscillating_metrics.sort_values( + "log_pos_neg_ratio", ascending=True + ).head(10) + for _, row in top_negative_biased.iterrows(): + print( + f" {row['base_reaction']}: log_ratio={row['log_pos_neg_ratio']:.3f}, " + f"pos_ratio={row['positive_ratio']:.3f}, neg_ratio={row['negative_ratio']:.3f}, " + f"freq={row['oscillation_frequency']:.6f}" + ) + + +def save_oscillating_results( + oscillating_metrics: pd.DataFrame, + base_reaction_details: dict, + base_reaction_mapping: dict, + metadata: dict, + outdir: str, +): + """Save oscillating reaction results to CSV files with experiment metadata in outdir.""" + os.makedirs(outdir, exist_ok=True) + prefix = "oscillating_reaction_analysis" + + # Save metrics + metrics_filename = os.path.join(outdir, f"{prefix}_dynamics_metrics.csv") + oscillating_metrics.to_csv(metrics_filename, index=False) + print( + f"\n[INFO] Oscillating reaction dynamics metrics saved to '{metrics_filename}'" + ) + + # Save mapping for oscillating reactions + oscillating_reactions = ( + set(oscillating_metrics["base_reaction"].tolist()) + if len(oscillating_metrics) > 0 + else set() + ) + mapping_data = [] + for base_reaction, info in base_reaction_mapping.items(): + if base_reaction in oscillating_reactions: + mapping_data.append( + { + "base_reaction": base_reaction, + "forward_extended": "; ".join(info["forward_extended"]), + "reverse_extended": "; ".join(info["reverse_extended"]), + "n_forward_extended": len(info["forward_extended"]), + "n_reverse_extended": len(info["reverse_extended"]), + "total_extended": len(info["all_extended"]), + } + ) + + if mapping_data: + mapping_df = pd.DataFrame(mapping_data) + mapping_filename = os.path.join(outdir, f"{prefix}_extended_mapping.csv") + mapping_df.to_csv(mapping_filename, index=False) + print( + f"[INFO] Oscillating reaction extended mapping saved to '{mapping_filename}'" + ) + else: + print("[INFO] No oscillating mapping data to save.") + + # Save metadata (only simple types) + metadata_filename = os.path.join(outdir, f"{prefix}_metadata.csv") + metadata_for_csv = { + k: v for k, v in metadata.items() if not isinstance(v, (dict, list, np.ndarray)) + } + metadata_for_csv["n_oscillating_reactions"] = ( + len(oscillating_metrics) if oscillating_metrics is not None else 0 + ) + metadata_df = pd.DataFrame([metadata_for_csv]) + metadata_df.to_csv(metadata_filename, index=False) + print(f"[INFO] Oscillating analysis metadata saved to '{metadata_filename}'") + + +def plot( + params: dict[str, Any], + conn: DuckDBPyConnection, + history_sql: str, + config_sql: str, + success_sql: str, # kept for compatibility though not used in this focused analysis + sim_data_dict: dict[str, dict[int, str]], + validation_data_paths: list[str], + outdir: str, + variant_metadata: dict[str, dict[int, Any]], + variant_names: dict[str, str], +): + """ + Main entry point: loads data via DuckDB, computes base net fluxes, finds oscillating reactions, + computes dynamic metrics for oscillating reactions, produces the semi-log heat scatter plot, + and saves results to outdir. + + params keys: + - eps: tolerance for zero comparisons (default 1e-30) + """ + eps = params.get("eps", 1e-30) + + print("[INFO] Starting oscillating base reaction analysis...") + print(f"[INFO] Parameters: eps={eps}") + print(f"[INFO] Output directory: {outdir}") + + try: + # Load data + df, metadata = load_fba_data(conn, history_sql, config_sql, sim_data_dict) + + # Extract extended fluxes (drop 'time') + if "time" in df.columns: + extended_flux_df = df.drop(columns=["time"]) + else: + extended_flux_df = df.copy() + + print( + f"[INFO] Loaded extended flux matrix: {extended_flux_df.shape[1]} extended reactions, {extended_flux_df.shape[0]} time points" + ) + + # Map extended -> base reactions + base_reaction_mapping, extended_to_base_map = map_extended_to_base_reactions( + extended_flux_df.columns.tolist(), metadata["base_to_extended_mapping"] + ) + + # Compute base net fluxes + base_flux_df, base_reaction_details = compute_base_reaction_fluxes( + extended_flux_df, base_reaction_mapping + ) + + # Categorize to find oscillating + _, _, oscillating, _, base_reaction_categories = ( + categorize_base_reactions_by_flux_behavior(base_flux_df, eps) + ) + + if len(oscillating) == 0: + print("[INFO] No oscillating base reactions detected. Exiting.") + return None + + # Compute oscillating dynamics metrics + oscillating_metrics = compute_oscillating_dynamics_metrics( + base_flux_df, oscillating, eps + ) + + # Print oscillating summaries + print_oscillating_reaction_summaries(oscillating_metrics) + + # Create plot + create_oscillating_dynamics_plot( + oscillating_metrics, base_reaction_details, outdir, eps + ) + + # Save oscillating results + save_oscillating_results( + oscillating_metrics, + base_reaction_details, + base_reaction_mapping, + metadata, + outdir, + ) + + print("\n[INFO] Oscillating base reaction analysis complete.") + print(f"[INFO] All files saved to directory: {outdir}") + + return { + "oscillating_metrics": oscillating_metrics, + "base_reaction_details": base_reaction_details, + "base_reaction_mapping": base_reaction_mapping, + "metadata": metadata, + } + + except Exception as e: + print(f"[ERROR] Analysis failed: {str(e)}") + import traceback + + traceback.print_exc() + return None diff --git a/ecoli/analysis/multigeneration/fba_generation_average_to_csv.py b/ecoli/analysis/multigeneration/fba_generation_average_to_csv.py new file mode 100644 index 000000000..f3b9bdd07 --- /dev/null +++ b/ecoli/analysis/multigeneration/fba_generation_average_to_csv.py @@ -0,0 +1,224 @@ +""" +Export FBA reaction net fluxes to CSV format with generation-wise averages. + +For each specified BioCyc_ID, creates a CSV where: +- Each row represents a BioCyc reaction ID +- Each column represents the average flux value for generation i +- Final column contains the overall average across all time steps + +Usage parameters: + "fba_flux_csv": { + # Required: specify BioCyc reaction IDs to export + "BioCyc_ID": ["Name1", "Name2", ...], + # Optional: specify generations to include + # If not specified, all generations will be used + "generation": [1, 2, ...], + # Optional: output filename (default: "fba_generation_average_summary.csv") + "output_filename": "custom_filename.csv" + } +""" + +import os +import pandas as pd +import polars as pl +from typing import Any +from duckdb import DuckDBPyConnection + +from ecoli.library.parquet_emitter import field_metadata +from ecoli.analysis.utils import ( + create_base_to_extended_mapping, + build_flux_calculation_sql, +) + + +def plot( + params: dict[str, Any], + conn: DuckDBPyConnection, + history_sql: str, + config_sql: str, + success_sql: str, + sim_data_dict: dict[str, dict[int, str]], + validation_data_paths: list[str], + outdir: str, + variant_metadata: dict[str, dict[int, Any]], + variant_names: dict[str, str], +): + """Export FBA reaction net fluxes to CSV with generation-wise averages.""" + + # Get parameters + biocyc_ids = params.get("BioCyc_ID", []) + if not biocyc_ids: + print( + "[ERROR] No BioCyc_ID found in params. Please specify reaction IDs to export." + ) + return None + + if isinstance(biocyc_ids, str): + biocyc_ids = [biocyc_ids] + + output_filename = params.get( + "output_filename", "fba_generation_average_summary.csv" + ) + target_generations = params.get("generation", None) + + print(f"[INFO] Exporting net fluxes for {len(biocyc_ids)} reactions: {biocyc_ids}") + if target_generations: + print(f"[INFO] Filtering for generations: {target_generations}") + + # Create base to extended reaction mapping + base_to_extended_mapping = create_base_to_extended_mapping(sim_data_dict) + if not base_to_extended_mapping: + print("[ERROR] Could not create base to extended reaction mapping") + return None + + # Load reaction IDs from config + try: + all_reaction_ids = field_metadata( + conn, config_sql, "listeners__fba_results__reaction_fluxes" + ) + print(f"[INFO] Total reactions in sim_data: {len(all_reaction_ids)}") + except Exception as e: + print(f"[ERROR] Error loading reaction IDs: {e}") + return None + + # Build SQL query for efficient flux calculation + flux_calculation_sql, valid_biocyc_ids = build_flux_calculation_sql( + biocyc_ids, base_to_extended_mapping, all_reaction_ids, history_sql + ) + + if not flux_calculation_sql or not valid_biocyc_ids: + print("[ERROR] Could not build flux calculation SQL") + return None + + print(f"[INFO] Processing {len(valid_biocyc_ids)} valid BioCyc IDs") + + # Execute the optimized SQL query + try: + df = conn.sql(flux_calculation_sql).pl() + print(f"[INFO] Loaded data with {df.height} time steps") + except Exception as e: + print(f"[ERROR] Error executing flux calculation SQL: {e}") + return None + + if df.is_empty(): + print("[ERROR] No data found") + return None + + # Filter by specified generations if provided + if target_generations is not None: + print(f"[INFO] Filtering for generations: {target_generations}") + df = df.filter(pl.col("generation").is_in(target_generations)) + + # Get unique generations and sort them + unique_generations = sorted(df["generation"].unique().to_list()) + print(f"[INFO] Found {len(unique_generations)} generations: {unique_generations}") + + # Calculate averages for each BioCyc ID + csv_data = calculate_generation_averages(df, valid_biocyc_ids, unique_generations) + + if csv_data is None or csv_data.empty: + print("[ERROR] No valid data to export") + return None + + # Save to CSV + output_path = os.path.join(outdir, output_filename) + csv_data.to_csv(output_path, index=False) + print(f"[INFO] Successfully exported flux data to: {output_path}") + print( + f"[INFO] CSV contains {len(csv_data)} rows and {len(csv_data.columns)} columns" + ) + + return csv_data + + +def calculate_generation_averages(df, valid_biocyc_ids, unique_generations): + """ + Calculate average flux values for each generation and overall average. + + Returns a pandas DataFrame with: + - BioCyc_ID column + - One column per generation (Gen_1_Avg, Gen_2_Avg, etc.) + - Overall_Average column + """ + + results = [] + + for biocyc_id in valid_biocyc_ids: + net_flux_col = f"{biocyc_id}_net_flux" + + if net_flux_col not in df.columns: + print(f"[WARNING] Column {net_flux_col} not found in dataframe") + continue + + # Initialize row data + row_data = {"BioCyc_ID": biocyc_id} + + # Calculate average for each generation + generation_averages = [] + for gen in unique_generations: + gen_data = df.filter(pl.col("generation") == gen) + if gen_data.height > 0: + gen_avg = gen_data[net_flux_col].mean() + if gen_avg is not None: + row_data[f"Gen_{gen}_Avg"] = gen_avg + generation_averages.append(gen_avg) + else: + row_data[f"Gen_{gen}_Avg"] = 0.0 + else: + row_data[f"Gen_{gen}_Avg"] = 0.0 + + # Calculate overall average across all time steps + overall_avg = df[net_flux_col].mean() + row_data["Overall_Average"] = overall_avg if overall_avg is not None else 0.0 + + results.append(row_data) + + if not results: + print("[ERROR] No valid results calculated") + return None + + # Convert to pandas DataFrame for easy CSV export + csv_data = pd.DataFrame(results) + + # Reorder columns: BioCyc_ID first, then generation columns in order, then Overall_Average + column_order = ["BioCyc_ID"] + for gen in unique_generations: + column_order.append(f"Gen_{gen}_Avg") + column_order.append("Overall_Average") + + csv_data = csv_data[column_order] + + return csv_data + + +def validate_parameters(params): + """Validate input parameters.""" + + if not isinstance(params, dict): + return False, "Parameters must be a dictionary" + + biocyc_ids = params.get("BioCyc_ID", []) + if not biocyc_ids: + return False, "BioCyc_ID parameter is required" + + if isinstance(biocyc_ids, str): + biocyc_ids = [biocyc_ids] + + if not isinstance(biocyc_ids, list): + return False, "BioCyc_ID must be a string or list of strings" + + target_generations = params.get("generation", None) + if target_generations is not None: + if not isinstance(target_generations, list): + return False, "generation parameter must be a list of integers" + if not all(isinstance(x, int) for x in target_generations): + return False, "All generation values must be integers" + + output_filename = params.get("output_filename", "fba_flux_summary.csv") + if not isinstance(output_filename, str): + return False, "output_filename must be a string" + + if not output_filename.endswith(".csv"): + return False, "output_filename must end with .csv" + + return True, "Parameters are valid" diff --git a/ecoli/analysis/multigeneration/fba_heat_scatter.py b/ecoli/analysis/multigeneration/fba_heat_scatter.py new file mode 100644 index 000000000..087afcfb5 --- /dev/null +++ b/ecoli/analysis/multigeneration/fba_heat_scatter.py @@ -0,0 +1,1372 @@ +""" +Base reaction flux analysis script with per-generation common-category extraction and burst detection. + +This script contains the original helper functions for loading FBA data, mapping +extended -> base reactions, computing base reaction net fluxes, categorizing base +reactions by flux behavior (always_positive, always_negative, oscillating, +always_zero), plotting utilities, and saving results. + +Added functionality: +- For each generation present in the history SQL, compute base reaction categories + using the exact same categorization logic. +- Compute the intersection (common base reactions) across all generations for + each category. +- Save those common reaction names into four separate CSV files in outdir: + - common_base_reactions_always_zero.csv + - common_base_reactions_always_positive.csv + - common_base_reactions_always_negative.csv + - common_base_reactions_oscillating.csv +- Compute "burst" base reactions for categories always_positive, always_negative, and oscillating. + A base reaction is considered a burst (for a category) if, for every generation, + when that reaction belongs to that category in that generation, its zero_ratio + (fraction of timepoints with essentially zero flux) is > burst_threshold. + The default burst_threshold is 0.1 (configurable via params["burst_threshold"]). +- Save three CSVs (one-column each) listing burst base reaction names: + - burst_base_reaction_always_positive.csv + - burst_base_reaction_always_negative.csv + - burst_base_reaction_oscillating.csv + +Important constraints preserved: +- The data loading function `load_fba_data` is left intact (not modified). +- The category determination logic (in `categorize_base_reactions_by_flux_behavior`) + is not altered. +- `plot(...)` remains present and acts as the main entry point. + +Usage: + Call plot(...) with the same parameters expected by the original script. +""" + +import os +from typing import Any +import pandas as pd +import numpy as np +import plotly.graph_objects as go +from plotly.subplots import make_subplots +from scipy.stats import gaussian_kde +from duckdb import DuckDBPyConnection + +# --- BEGIN: Existing helper imports and functions (kept exactly as requested) --- +# Note: These functions are identical to the code you provided. I have not changed +# the internals of load_fba_data, or the categorization logic. +from ecoli.library.parquet_emitter import field_metadata +from ecoli.analysis.utils import create_base_to_extended_mapping + + +def load_fba_data( + conn: DuckDBPyConnection, history_sql: str, config_sql: str, sim_data_dict: dict +) -> tuple[pd.DataFrame, dict]: + """ + Load FBA flux data using DuckDB connection and SQL queries. + + Parameters: + - conn: DuckDB connection + - history_sql: SQL query for historical data + - config_sql: SQL query for configuration data + - sim_data_dict: Dictionary with sim_data information + + Returns: + - df: DataFrame with flux data (time points x extended reactions) + - metadata: Dictionary with experiment metadata + """ + print("[INFO] Loading FBA flux data via SQL...") + + try: + # Create base to extended reaction mapping + base_to_extended_mapping = create_base_to_extended_mapping(sim_data_dict) + if not base_to_extended_mapping: + raise Exception("Could not create base to extended reaction mapping") + + # Load the reaction IDs from the config - this is the array that maps to flux matrix columns + reaction_ids = field_metadata( + conn, config_sql, "listeners__fba_results__reaction_fluxes" + ) + print(f"[INFO] Total reactions in sim_data: {len(reaction_ids)}") + + # Required columns for the query + required_columns = [ + "time", + "generation", + "listeners__fba_results__reaction_fluxes", + ] + + # Build SQL query + sql = f""" + SELECT {", ".join(required_columns)} + FROM ({history_sql}) + ORDER BY generation, time + """ + + # Execute query + df_pl = conn.sql(sql).pl() + + if df_pl.is_empty(): + raise Exception("No data found") + print(f"[INFO] Loaded data with {df_pl.height} time steps") + + # Extract flux matrix and convert to pandas DataFrame + flux_matrix = df_pl["listeners__fba_results__reaction_fluxes"].to_numpy() + flux_matrix = np.array([np.array(row) for row in flux_matrix]) + + # Create DataFrame with extended reactions + df = pd.DataFrame(flux_matrix, columns=reaction_ids) + + # Add time information + time_data = df_pl.select(["time"]).to_pandas() + df = pd.concat([time_data, df], axis=1) + + # Drop initial time point where time == 0 + df = df[df["time"] != 0].reset_index(drop=True) + + # Create metadata dictionary + metadata = { + "n_extended_reactions": len(reaction_ids), + "n_timepoints": len(df), + "extended_reaction_names": reaction_ids, + "base_to_extended_mapping": base_to_extended_mapping, + } + + print("[INFO] Successfully loaded data:") + print(f" - Time points: {len(df)}") + print(f" - Extended reactions: {len(reaction_ids)}") + print(f" - Base reaction mapping entries: {len(base_to_extended_mapping)}") + + return df, metadata + + except Exception as e: + print(f"[ERROR] Failed to load data: {str(e)}") + raise + + +def map_extended_to_base_reactions(extended_reactions, base_to_extended_mapping): + """ + Map extended reactions to base reactions and identify forward/reverse relationships. + + Parameters: + - extended_reactions: List of extended reaction names + - base_to_extended_mapping: Dict mapping base reaction ID to list of extended reaction names + + Returns: + - base_reaction_mapping: Dict with base reaction info including forward/reverse extended reactions + - extended_to_base_map: Dict mapping each extended reaction to its base reaction + """ + # Create reverse mapping from extended to base + extended_to_base_map = {} + for base_rxn, extended_list in base_to_extended_mapping.items(): + for extended_rxn in extended_list: + extended_to_base_map[extended_rxn] = base_rxn + + base_reaction_mapping = {} + + for extended_reaction in extended_reactions: + # Get base reaction name from mapping + base_reaction = extended_to_base_map.get(extended_reaction) + + if base_reaction is None: + print( + f"[WARNING] No base reaction found for extended reaction: {extended_reaction}" + ) + continue + + # Initialize base reaction entry if not exists + if base_reaction not in base_reaction_mapping: + base_reaction_mapping[base_reaction] = { + "forward_extended": [], + "reverse_extended": [], + "all_extended": [], + } + + # Determine if this is a forward or reverse extended reaction + if extended_reaction.endswith(" (reverse)"): + base_reaction_mapping[base_reaction]["reverse_extended"].append( + extended_reaction + ) + else: + base_reaction_mapping[base_reaction]["forward_extended"].append( + extended_reaction + ) + + base_reaction_mapping[base_reaction]["all_extended"].append(extended_reaction) + + print("[INFO] Base reaction mapping results:") + print(f" - Total base reactions: {len(base_reaction_mapping)}") + print(f" - Extended reactions mapped: {len(extended_to_base_map)}") + + # Print statistics about forward/reverse distributions + forward_only = sum( + 1 + for info in base_reaction_mapping.values() + if len(info["forward_extended"]) > 0 and len(info["reverse_extended"]) == 0 + ) + reverse_only = sum( + 1 + for info in base_reaction_mapping.values() + if len(info["forward_extended"]) == 0 and len(info["reverse_extended"]) > 0 + ) + both_directions = sum( + 1 + for info in base_reaction_mapping.values() + if len(info["forward_extended"]) > 0 and len(info["reverse_extended"]) > 0 + ) + + print(f" - Base reactions with forward extended only: {forward_only}") + print(f" - Base reactions with reverse extended only: {reverse_only}") + print( + f" - Base reactions with both forward and reverse extended: {both_directions}" + ) + + return base_reaction_mapping, extended_to_base_map + + +def compute_base_reaction_fluxes(flux_df, base_reaction_mapping): + """ + Compute base reaction fluxes by summing forward extended and subtracting reverse extended fluxes. + + Parameters: + - flux_df: DataFrame with extended reaction flux values + - base_reaction_mapping: Dict with base reaction info + + Returns: + - base_flux_df: DataFrame with base reaction net flux values + - base_reaction_details: Dict with detailed info about each base reaction + """ + base_flux_data = {} + base_reaction_details = {} + + for base_reaction, info in base_reaction_mapping.items(): + forward_extended = info["forward_extended"] + reverse_extended = info["reverse_extended"] + + # Sum forward extended fluxes + forward_flux = pd.Series(0.0, index=flux_df.index) + if forward_extended: + for ext_reaction in forward_extended: + if ext_reaction in flux_df.columns: + forward_flux += flux_df[ext_reaction] + + # Sum reverse extended fluxes + reverse_flux = pd.Series(0.0, index=flux_df.index) + if reverse_extended: + for ext_reaction in reverse_extended: + if ext_reaction in flux_df.columns: + reverse_flux += flux_df[ext_reaction] + + # Net flux = forward - reverse + net_flux = forward_flux - reverse_flux + base_flux_data[base_reaction] = net_flux + + # Store details for analysis + base_reaction_details[base_reaction] = { + "forward_extended": forward_extended, + "reverse_extended": reverse_extended, + "n_forward_extended": len(forward_extended), + "n_reverse_extended": len(reverse_extended), + "total_extended": len(info["all_extended"]), + } + + base_flux_df = pd.DataFrame(base_flux_data) + + print("[INFO] Base reaction flux computation results:") + print(f" - Base reactions computed: {len(base_flux_df.columns)}") + print(f" - Time points: {len(base_flux_df)}") + + return base_flux_df, base_reaction_details + + +def categorize_base_reactions_by_flux_behavior(base_flux_df, eps=1e-30): + """ + Categorize base reactions based on their flux behavior across time steps. + + Parameters: + - base_flux_df: DataFrame with base reaction net flux values + - eps: Small tolerance for zero comparison + + Returns: + - always_positive: List of base reactions that are always >= 0 and have max > 0 + - always_negative: List of base reactions that are always <= 0 and have max abs > 0 + - oscillating: List of base reactions that change sign + - always_zero: List of base reactions that are always zero + - base_reaction_categories: Dictionary with detailed categorization info + """ + always_positive = [] + always_negative = [] + oscillating = [] + always_zero = [] + base_reaction_categories = {} + + for base_reaction in base_flux_df.columns: + flux_values = base_flux_df[base_reaction].values + + # Check for positive, negative, and zero values + has_positive = np.any(flux_values > eps) + has_negative = np.any(flux_values < -eps) + has_zero = np.any(np.abs(flux_values) <= eps) + + min_flux = flux_values.min() + max_flux = flux_values.max() + max_abs_flux = np.max(np.abs(flux_values)) + + # Categorize based on behavior + if max_abs_flux <= eps: # All values are essentially zero + always_zero.append(base_reaction) + category = "always_zero" + elif not has_negative and has_positive: # All values >= -eps and has some > eps + always_positive.append(base_reaction) + category = "always_positive" + elif not has_positive and has_negative: # All values <= eps and has some < -eps + always_negative.append(base_reaction) + category = "always_negative" + elif has_positive and has_negative: # Has both positive and negative values + oscillating.append(base_reaction) + category = "oscillating" + else: + # This case should be covered by always_zero, but keep as safety net + always_zero.append(base_reaction) + category = "always_zero" + + base_reaction_categories[base_reaction] = { + "category": category, + "min_flux": min_flux, + "max_flux": max_flux, + "max_abs_flux": max_abs_flux, + "has_positive": has_positive, + "has_negative": has_negative, + "has_zero": has_zero, + } + + print("\n[INFO] Base reaction categorization by flux behavior:") + print(f" - Always positive (>= 0, max > 0): {len(always_positive)}") + print(f" - Always negative (<= 0, max abs > 0): {len(always_negative)}") + print(f" - Oscillating (changes sign): {len(oscillating)}") + print(f" - Always zero (max abs ≈ 0): {len(always_zero)}") + + return ( + always_positive, + always_negative, + oscillating, + always_zero, + base_reaction_categories, + ) + + +def print_base_reaction_category_summaries( + positive_df, negative_df, oscillating_df, always_zero_df +): + """Print summary information for each base reaction category.""" + + if len(positive_df) > 0: + print( + "\n[INFO] Always Positive Base Reactions (max flux > 0) - Top 5 most active (lowest zero ratio):" + ) + top_positive = positive_df.sort_values("zero_ratio").head(5) + for _, row in top_positive.iterrows(): + print( + f" {row['base_reaction']}: zero_ratio={row['zero_ratio']:.4f}, max_flux={row['max_flux']:.2e}, ext_reactions={row['total_extended']}" + ) + + if len(negative_df) > 0: + print( + "\n[INFO] Always Negative Base Reactions (max abs flux > 0) - Top 5 most active (lowest zero ratio):" + ) + top_negative = negative_df.sort_values("zero_ratio").head(5) + for _, row in top_negative.iterrows(): + print( + f" {row['base_reaction']}: zero_ratio={row['zero_ratio']:.4f}, max_abs_flux={row['max_abs_flux']:.2e}, ext_reactions={row['total_extended']}" + ) + + if len(oscillating_df) > 0: + print( + "\n[INFO] Oscillating Base Reactions - Top 5 most active (lowest zero ratio):" + ) + top_oscillating = oscillating_df.sort_values("zero_ratio").head(5) + for _, row in top_oscillating.iterrows(): + print( + f" {row['base_reaction']}: zero_ratio={row['zero_ratio']:.4f}, max_abs_flux={row['max_abs_flux']:.2e}, ext_reactions={row['total_extended']}" + ) + + if len(always_zero_df) > 0: + print("\n[INFO] Always Zero Base Reactions - First 5 examples:") + first_zero = always_zero_df.head(5) + for _, row in first_zero.iterrows(): + print( + f" {row['base_reaction']}: max_abs_flux={row['max_abs_flux']:.2e}, ext_reactions={row['total_extended']}" + ) + + +def create_base_reaction_flux_plots( + flux_data, direction, filename_suffix, epsilon_log, base_reaction_details, outdir +): + """Create heat scatter and simple scatter plots for base reactions in a given flux direction.""" + + if len(flux_data) == 0: + return + + # Prepare data for plotting + x = flux_data["log_max_abs_flux"] + y = flux_data["zero_ratio"] + + # Calculate point density using gaussian_kde for heat scatter plot + if len(flux_data) > 1: # Need at least 2 points for KDE + xy = np.vstack([x, y]) + density = gaussian_kde(xy)(xy) + else: + density = np.array([1.0]) # Single point gets density of 1 + + # Create hover text with base reaction information + hover_text = [] + simple_hover_text = [] + + for idx, row in flux_data.iterrows(): + base_reaction = row["base_reaction"] + details = base_reaction_details.get(base_reaction, {}) + + # Create extended reaction info + forward_ext = details.get("forward_extended", []) + reverse_ext = details.get("reverse_extended", []) + + ext_info = f"Forward: {len(forward_ext)} extended, Reverse: {len(reverse_ext)} extended" + if len(forward_ext) <= 3: + ext_info += ( + f"
Forward: {', '.join(forward_ext) if forward_ext else 'None'}" + ) + if len(reverse_ext) <= 3: + ext_info += ( + f"
Reverse: {', '.join(reverse_ext) if reverse_ext else 'None'}" + ) + + # For heat scatter (with density) + hover_text.append( + f"Base Reaction: {base_reaction}
" + + f"Extended Reactions: {ext_info}
" + + f"Category: {row['category']}
" + + f"Zero Ratio: {row['zero_ratio']:.4f}
" + + f"Max |Net Flux|: {row['max_abs_flux']:.2e}
" + + f"Min Net Flux: {row['min_flux']:.2e}
" + + f"Max Net Flux: {row['max_flux']:.2e}
" + + f"Log |Max Net Flux|: {row['log_max_abs_flux']:.2f}
" + + f"Point Density: {density[list(flux_data.index).index(idx)]:.6f}" + ) + + # For simple scatter (without density) + simple_hover_text.append( + f"Base Reaction: {base_reaction}
" + + f"Extended Reactions: {ext_info}
" + + f"Category: {row['category']}
" + + f"Zero Ratio: {row['zero_ratio']:.4f}
" + + f"Max |Net Flux|: {row['max_abs_flux']:.2e}
" + + f"Min Net Flux: {row['min_flux']:.2e}
" + + f"Max Net Flux: {row['max_flux']:.2e}
" + + f"Log |Max Net Flux|: {row['log_max_abs_flux']:.2f}" + ) + + # 1. HEAT SCATTER PLOT WITH MARGINAL HISTOGRAMS + # Create subplot with marginal histograms + fig_heat = make_subplots( + rows=2, + cols=2, + column_widths=[0.9, 0.1], + row_heights=[0.1, 0.9], + specs=[ + [{"secondary_y": False}, {"secondary_y": False}], + [{"secondary_y": False}, {"secondary_y": False}], + ], + vertical_spacing=0.05, + horizontal_spacing=0.05, + subplot_titles=("", "", "", ""), + ) + + # Main scatter plot (bottom left, row=2, col=1) + fig_heat.add_trace( + go.Scatter( + x=x, + y=y, + mode="markers", + marker=dict( + size=8, + color=density, + colorscale="Plasma", + opacity=0.8, + colorbar=dict( + title=dict(text="Point Density", font=dict(size=14)), + tickfont=dict(size=12), + thickness=15, + len=0.7, + x=1.02, # Position colorbar to the right + ), + line=dict(width=0.5, color="white"), + ), + text=hover_text, + hovertemplate="%{text}", + name="Base Reactions", + showlegend=False, + ), + row=2, + col=1, + ) + + # Top density curve (x-axis distribution, row=1, col=1) + if len(x) > 1: + # Create smooth density curve for x-axis + x_range = np.linspace(x.min(), x.max(), 250) + x_density = gaussian_kde(x) + x_density_values = x_density(x_range) + else: + # Single point case + x_range = np.array([x.iloc[0]]) + x_density_values = np.array([1.0]) + + fig_heat.add_trace( + go.Scatter( + x=x_range, + y=x_density_values, + mode="lines", + line=dict(color="steelblue", width=3), + fill="tozeroy", + fillcolor="rgba(70, 130, 180, 0.3)", + name="X Density", + showlegend=False, + ), + row=1, + col=1, + ) + + # Right density curve (y-axis distribution, row=2, col=2) + if len(y) > 1: + # Create smooth density curve for y-axis + y_range = np.linspace(y.min(), y.max(), 250) + y_density = gaussian_kde(y) + y_density_values = y_density(y_range) + else: + # Single point case + y_range = np.array([y.iloc[0]]) + y_density_values = np.array([1.0]) + + fig_heat.add_trace( + go.Scatter( + x=y_density_values, + y=y_range, + mode="lines", + line=dict(color="lightcoral", width=3), + fill="tozerox", + fillcolor="rgba(240, 128, 128, 0.3)", + name="Y Density", + showlegend=False, + ), + row=2, + col=2, + ) + + # Update layout for heat scatter with histograms + fig_heat.update_layout( + title=dict( + text=f"Heat Scatter Plot with Marginal Density Curves: {direction} Base Reaction Net Flux", + font=dict(size=18), + x=0.5, + xanchor="center", + ), + plot_bgcolor="white", + paper_bgcolor="white", + font=dict(family="Arial", size=12), + width=1000, + height=800, + margin=dict(l=80, r=120, t=100, b=80), + ) + + # Update axes for main plot + fig_heat.update_xaxes( + title=dict(text="log₁₀(ε + |Max Net Flux|)", font=dict(size=14)), + tickfont=dict(size=12), + gridcolor="rgba(128,128,128,0.2)", + gridwidth=1, + row=2, + col=1, + ) + fig_heat.update_yaxes( + title=dict(text="Zero Flux Ratio", font=dict(size=14)), + tickfont=dict(size=12), + gridcolor="rgba(128,128,128,0.2)", + gridwidth=1, + row=2, + col=1, + ) + + # Update axes for histograms (remove tick labels and titles) + fig_heat.update_xaxes(showticklabels=False, title="", row=1, col=1) + fig_heat.update_yaxes(showticklabels=False, title="", row=1, col=1) + fig_heat.update_xaxes(showticklabels=False, title="", row=2, col=2) + fig_heat.update_yaxes(showticklabels=False, title="", row=2, col=2) + + # Hide the top-right subplot + fig_heat.update_xaxes(visible=False, row=1, col=2) + fig_heat.update_yaxes(visible=False, row=1, col=2) + + # Add statistics annotation + stats_text = ( + f"Base Reactions ({direction}): {len(flux_data):,}
" + + f"ε = {epsilon_log:.0e}
" + + f"|Max Net Flux| Range: {flux_data['max_abs_flux'].min():.2e} to {flux_data['max_abs_flux'].max():.2e}
" + + f"Zero Ratio Range: {flux_data['zero_ratio'].min():.4f} to {flux_data['zero_ratio'].max():.4f}
" + + f"Extended Reactions Range: {flux_data['total_extended'].min()} to {flux_data['total_extended'].max()}" + ) + + if len(flux_data) > 1: + stats_text += f"
Density Range: {density.min():.2e} - {density.max():.2e}" + + fig_heat.add_annotation( + x=0.02, + y=0.48, + xref="paper", + yref="paper", + text=stats_text, + showarrow=False, + font=dict(size=11, color="black"), + bgcolor="rgba(255,255,255,0.9)", + bordercolor="rgba(128,128,128,0.5)", + borderwidth=1, + borderpad=10, + xanchor="left", + yanchor="top", + ) + + # 2. SIMPLE SCATTER PLOT (original version without histograms) + fig_simple = go.Figure() + + fig_simple.add_trace( + go.Scatter( + x=x, + y=y, + mode="markers", + marker=dict( + size=6, + color="steelblue", + opacity=0.7, + line=dict(width=0.5, color="white"), + ), + text=simple_hover_text, + hovertemplate="%{text}", + name="Base Reactions", + ) + ) + + fig_simple.update_layout( + title=dict( + text=f"Simple Scatter Plot: {direction} Base Reaction Net Flux - Zero Ratio vs |Max Net Flux|", + font=dict(size=18), + x=0.5, + xanchor="center", + ), + xaxis=dict( + title=dict(text="log₁₀(ε + |Max Net Flux|)", font=dict(size=14)), + tickfont=dict(size=12), + gridcolor="rgba(128,128,128,0.2)", + gridwidth=1, + ), + yaxis=dict( + title=dict(text="Zero Flux Ratio", font=dict(size=14)), + tickfont=dict(size=12), + gridcolor="rgba(128,128,128,0.2)", + gridwidth=1, + ), + plot_bgcolor="white", + paper_bgcolor="white", + font=dict(family="Arial", size=12), + width=900, + height=700, + margin=dict(l=80, r=80, t=100, b=80), + showlegend=False, + ) + + # Add statistics for simple plot + simple_stats_text = ( + f"Base Reactions ({direction}): {len(flux_data):,}
" + + f"ε = {epsilon_log:.0e}
" + + f"|Max Net Flux| Range: {flux_data['max_abs_flux'].min():.2e} to {flux_data['max_abs_flux'].max():.2e}
" + + f"Zero Ratio Range: {flux_data['zero_ratio'].min():.4f} to {flux_data['zero_ratio'].max():.4f}
" + + f"Extended Reactions Range: {flux_data['total_extended'].min()} to {flux_data['total_extended'].max()}" + ) + + fig_simple.add_annotation( + x=0.02, + y=0.48, + xref="paper", + yref="paper", + text=simple_stats_text, + showarrow=False, + font=dict(size=11, color="black"), + bgcolor="rgba(255,255,255,0.9)", + bordercolor="rgba(128,128,128,0.5)", + borderwidth=1, + borderpad=10, + xanchor="left", + yanchor="top", + ) + + # Save plots + heat_filename = os.path.join( + outdir, + f"heat_scatter_with_density_curves_base_reactions_{filename_suffix}.html", + ) + simple_filename = os.path.join( + outdir, f"simple_scatter_plot_base_reactions_{filename_suffix}.html" + ) + + fig_heat.write_html(heat_filename) + fig_simple.write_html(simple_filename) + + print(f"\n[INFO] {direction} base reaction flux plots saved:") + print(f" - {heat_filename}") + print(f" - {simple_filename}") + + +def save_base_reaction_results( + comprehensive_df, + oscillating_df, + always_zero_df, + base_reaction_details, + base_reaction_mapping, + active_base_flux_df, + metadata, + outdir, +): + """Save all base reaction results to CSV files with experiment metadata in outdir.""" + + # Create output directory if it doesn't exist + os.makedirs(outdir, exist_ok=True) + + # Create filename prefix + prefix = "base_reaction_analysis" + + # Save comprehensive base reaction metrics + comprehensive_filename = os.path.join(outdir, f"{prefix}_metrics.csv") + comprehensive_df.to_csv(comprehensive_filename, index=False) + print( + f"\n[INFO] Comprehensive base reaction metrics saved to '{comprehensive_filename}'" + ) + + # Save oscillating base reactions specifically + if len(oscillating_df) > 0: + oscillating_filename = os.path.join(outdir, f"{prefix}_oscillating.csv") + oscillating_df.to_csv(oscillating_filename, index=False) + print(f"[INFO] Oscillating base reactions saved to '{oscillating_filename}'") + else: + print("[INFO] No oscillating base reactions found.") + + # Save always zero base reactions specifically + if len(always_zero_df) > 0: + zero_filename = os.path.join(outdir, f"{prefix}_always_zero.csv") + always_zero_df.to_csv(zero_filename, index=False) + print(f"[INFO] Always zero base reactions saved to '{zero_filename}'") + else: + print("[INFO] No always zero base reactions found.") + + # Save base reaction mapping details + mapping_data = [] + for base_reaction, info in base_reaction_mapping.items(): + mapping_data.append( + { + "base_reaction": base_reaction, + "forward_extended": "; ".join(info["forward_extended"]), + "reverse_extended": "; ".join(info["reverse_extended"]), + "n_forward_extended": len(info["forward_extended"]), + "n_reverse_extended": len(info["reverse_extended"]), + "total_extended": len(info["all_extended"]), + } + ) + + mapping_df = pd.DataFrame(mapping_data) + mapping_filename = os.path.join(outdir, f"{prefix}_extended_mapping.csv") + mapping_df.to_csv(mapping_filename, index=False) + print( + f"[INFO] Base reaction to extended reaction mapping saved to '{mapping_filename}'" + ) + + # Save filtered active base reaction flux data + flux_filename = os.path.join(outdir, f"{prefix}_filtered_flux.csv") + active_base_flux_df.to_csv(flux_filename, index=False, encoding="utf-8-sig") + print(f"[INFO] Filtered active base reactions saved to '{flux_filename}'") + + # Save metadata + metadata_filename = os.path.join(outdir, f"{prefix}_metadata.csv") + metadata_for_csv = { + k: v for k, v in metadata.items() if not isinstance(v, (dict, list, np.ndarray)) + } # Only save simple types + metadata_for_csv["n_base_reactions"] = len(comprehensive_df) + metadata_df = pd.DataFrame([metadata_for_csv]) + metadata_df.to_csv(metadata_filename, index=False) + print(f"[INFO] Experiment metadata saved to '{metadata_filename}'") + + # Print detailed summary statistics by category + print("\n[INFO] Detailed Summary Statistics by Category for Base Reactions:") + + for category in [ + "always_positive", + "always_negative", + "oscillating", + "always_zero", + ]: + cat_df = comprehensive_df[comprehensive_df["category"] == category] + if len(cat_df) > 0: + print( + f"\n {category.replace('_', ' ').title()} Base Reactions ({len(cat_df)}):" + ) + print( + f" Zero ratio range: {cat_df['zero_ratio'].min():.4f} - {cat_df['zero_ratio'].max():.4f}" + ) + print( + f" |Max net flux| range: {cat_df['max_abs_flux'].min():.2e} - {cat_df['max_abs_flux'].max():.2e}" + ) + print( + f" Min net flux range: {cat_df['min_flux'].min():.2e} - {cat_df['min_flux'].max():.2e}" + ) + print( + f" Max net flux range: {cat_df['max_flux'].min():.2e} - {cat_df['max_flux'].max():.2e}" + ) + print( + f" Extended reactions per base: {cat_df['total_extended'].min()} - {cat_df['total_extended'].max()}" + ) + + # Count reactions with different flux behaviors + has_zero_count = cat_df["has_zero"].sum() + print(f" Base reactions with zero flux points: {has_zero_count}") + + # Extended reaction statistics + total_forward_ext = cat_df["n_forward_extended"].sum() + total_reverse_ext = cat_df["n_reverse_extended"].sum() + print(f" Total forward extended reactions: {total_forward_ext}") + print(f" Total reverse extended reactions: {total_reverse_ext}") + + print(f"\nTotal base reactions: {len(comprehensive_df)}") + print(f"Total extended reactions mapped: {metadata['n_extended_reactions']}") + + +# --- END: Existing helper functions --- + +# --- BEGIN: New helpers & updated plot() that compute common base reaction names across generations and burst detection --- + + +def _get_distinct_generations(conn: DuckDBPyConnection, history_sql: str) -> list: + """ + Return a sorted list of distinct generation values from the history SQL. + + We treat history_sql as a subquery (it may already be a SELECT ...). + """ + gen_query = f"SELECT DISTINCT generation FROM ({history_sql}) ORDER BY generation" + gen_pl = conn.sql(gen_query).pl() + if gen_pl.is_empty(): + return [] + gen_df = gen_pl.to_pandas() + # Expect a column named 'generation' + generations = gen_df["generation"].tolist() + return generations + + +def _load_generation_flux_df( + conn: DuckDBPyConnection, history_sql: str, reaction_ids: list, generation: int +) -> pd.DataFrame: + """ + Load flux dataframe for a specific generation. + + Returns DataFrame where columns are: 'time' + reaction_ids (extended reaction columns). + Drops time==0 rows to match the main loader behavior. + """ + # Build SQL for this generation + required_columns = [ + "time", + "generation", + "listeners__fba_results__reaction_fluxes", + ] + sql = f""" + SELECT {", ".join(required_columns)} + FROM ({history_sql}) + WHERE generation = {generation} + ORDER BY time + """ + df_pl = conn.sql(sql).pl() + if df_pl.is_empty(): + return pd.DataFrame() # empty + + flux_matrix = df_pl["listeners__fba_results__reaction_fluxes"].to_numpy() + flux_matrix = np.array([np.array(row) for row in flux_matrix]) + df_ext = pd.DataFrame(flux_matrix, columns=reaction_ids) + + time_data = df_pl.select(["time"]).to_pandas() + df_full = pd.concat([time_data, df_ext], axis=1) + # Drop initial time point where time == 0 to remain consistent + df_full = df_full[df_full["time"] != 0].reset_index(drop=True) + return df_full + + +def plot( + params: dict[str, Any], + conn: DuckDBPyConnection, + history_sql: str, + config_sql: str, + success_sql: str, + sim_data_dict: dict[str, dict[int, str]], + validation_data_paths: list[str], + outdir: str, + variant_metadata: dict[str, dict[int, Any]], + variant_names: dict[str, str], +): + """ + Preprocesses FBA flux data by mapping extended reactions to base reactions, + computes net fluxes for base reactions (forward extended - reverse extended), + categorizes base reactions based on flux behavior, creates visualizations, + and additionally computes base reactions that are common across all generations + for each category (always_zero, always_positive, always_negative, oscillating). + + Also computes "burst" base reactions for categories always_positive, + always_negative, and oscillating: a base reaction is considered a burst if + in every generation the reaction (when in that category) has zero_ratio > + burst_threshold. + + Saves: + - Four CSVs with common reactions per category (one-column: base_reaction_name) + - Three CSVs with burst reactions per category (one-column: base_reaction_name) + + Returns the same tuple as the earlier design plus printing results. + """ + + # Get parameters with defaults + zero_threshold = params.get("zero_threshold", 0.999) + eps = params.get("eps", 1e-30) + epsilon_log = params.get("epsilon_log", 1e-30) + burst_threshold = params.get("burst_threshold", 0.1) + + print("[INFO] Starting base reaction flux analysis...") + print( + f"[INFO] Parameters: zero_threshold={zero_threshold}, eps={eps}, epsilon_log={epsilon_log}, burst_threshold={burst_threshold}" + ) + print(f"[INFO] Output directory: {outdir}") + + try: + # Load data (this uses the original loader - we DO NOT modify it) + df, metadata = load_fba_data(conn, history_sql, config_sql, sim_data_dict) + + # Separate flux data (drop time columns) + extended_flux_df = df.drop(columns=["time"]) + + print( + f"[INFO] Original data: {len(extended_flux_df.columns)} extended reactions, {len(extended_flux_df)} time points" + ) + + # Map extended reactions to base reactions using the full set + base_reaction_mapping, extended_to_base_map = map_extended_to_base_reactions( + extended_flux_df.columns.tolist(), metadata["base_to_extended_mapping"] + ) + + # Compute base reaction fluxes for the entire dataset + base_flux_df, base_reaction_details = compute_base_reaction_fluxes( + extended_flux_df, base_reaction_mapping + ) + + # Categorize base reactions by flux behavior for the entire dataset (keeps behavior logic unchanged) + ( + always_positive, + always_negative, + oscillating, + always_zero, + base_reaction_categories, + ) = categorize_base_reactions_by_flux_behavior(base_flux_df, eps) + + # Compute zero-flux ratio per base reaction (using absolute values) + zero_counts = (np.abs(base_flux_df) <= eps).sum(axis=0) + zero_ratio = zero_counts / base_flux_df.shape[0] + + # Create comprehensive dataframe with all metrics + comprehensive_data = [] + for base_reaction in base_flux_df.columns: + cat_info = base_reaction_categories[base_reaction] + details = base_reaction_details[base_reaction] + comprehensive_data.append( + { + "base_reaction": base_reaction, + "category": cat_info["category"], + "zero_ratio": zero_ratio[base_reaction], + "min_flux": cat_info["min_flux"], + "max_flux": cat_info["max_flux"], + "max_abs_flux": cat_info["max_abs_flux"], + "log_max_abs_flux": np.log10( + epsilon_log + cat_info["max_abs_flux"] + ), + "has_positive": cat_info["has_positive"], + "has_negative": cat_info["has_negative"], + "has_zero": cat_info["has_zero"], + "n_forward_extended": details["n_forward_extended"], + "n_reverse_extended": details["n_reverse_extended"], + "total_extended": details["total_extended"], + } + ) + + comprehensive_df = pd.DataFrame(comprehensive_data) + + # Filter out base reactions with high zero_ratio + active_base_reactions = comprehensive_df[ + comprehensive_df["zero_ratio"] < zero_threshold + ]["base_reaction"].tolist() + active_base_flux_df = base_flux_df[active_base_reactions] + print( + f"\n[INFO] Base reactions remaining after filtering (zero_ratio < {zero_threshold}): {len(active_base_reactions)}" + ) + + # Create separate datasets for visualization + positive_df = comprehensive_df[ + (comprehensive_df["category"] == "always_positive") + & (comprehensive_df["max_flux"] > 0) + ].copy() + + negative_df = comprehensive_df[ + (comprehensive_df["category"] == "always_negative") + & (comprehensive_df["max_abs_flux"] > 0) + ].copy() + + oscillating_df = comprehensive_df[ + comprehensive_df["category"] == "oscillating" + ].copy() + always_zero_df = comprehensive_df[ + comprehensive_df["category"] == "always_zero" + ].copy() + + # Print filtering results + print("\n[INFO] Filtered datasets for visualization:") + print(f" - Always positive with max flux > 0: {len(positive_df)}") + print(f" - Always negative with max abs flux > 0: {len(negative_df)}") + print(f" - Oscillating: {len(oscillating_df)}") + print(f" - Always zero: {len(always_zero_df)}") + + # Print top/bottom reactions for each category + print_base_reaction_category_summaries( + positive_df, negative_df, oscillating_df, always_zero_df + ) + + # Create visualizations with marginal histograms (unchanged behavior) + if len(positive_df) > 0: + create_base_reaction_flux_plots( + positive_df, + "Always Positive", + "always_positive", + epsilon_log, + base_reaction_details, + outdir, + ) + else: + print( + "\n[WARNING] No always positive base reactions found. Skipping positive plots." + ) + + if len(negative_df) > 0: + create_base_reaction_flux_plots( + negative_df, + "Always Negative", + "always_negative", + epsilon_log, + base_reaction_details, + outdir, + ) + else: + print( + "\n[WARNING] No always negative base reactions found. Skipping negative plots." + ) + + # Save the standard outputs (comprehensive and per-category CSVs, mapping, fluxes, metadata) + save_base_reaction_results( + comprehensive_df, + oscillating_df, + always_zero_df, + base_reaction_details, + base_reaction_mapping, + active_base_flux_df, + metadata, + outdir, + ) + + # + # NEW: Per-generation category extraction, intersections across generations, + # and burst detection; save results into separate single-column CSV files. + # + print( + "\n[INFO] Computing per-generation categories, intersections across generations, and burst detection..." + ) + + # 1) get list of distinct generations from the history_sql + generations = _get_distinct_generations(conn, history_sql) + print(f"[INFO] Found generations: {generations}") + + # Prepare a dict to collect sets per generation for each category + per_gen_category_sets = {} + # categories we care about + categories = [ + "always_zero", + "always_positive", + "always_negative", + "oscillating", + ] + for g in generations: + per_gen_category_sets[g] = {cat: set() for cat in categories} + + # Also prepare dicts to collect per-generation sets of reactions that satisfy zero_ratio > burst_threshold + per_gen_burst_candidate_sets = {} + for g in generations: + per_gen_burst_candidate_sets[g] = { + "always_positive": set(), + "always_negative": set(), + "oscillating": set(), + } + + # Make sure outdir exists + os.makedirs(outdir, exist_ok=True) + + # Helper to write a set of base names to a single-column CSV with header 'base_reaction_name' + def _write_one_column_csv(base_names_set: set, filepath: str): + if not base_names_set: + # create empty dataframe with correct column + pd.DataFrame(columns=["base_reaction_name"]).to_csv( + filepath, index=False + ) + else: + df_out = pd.DataFrame( + sorted(list(base_names_set)), columns=["base_reaction_name"] + ) + df_out.to_csv(filepath, index=False) + + # If no generation information found, still create the common and burst CSVs empty + if len(generations) == 0: + print( + "[WARNING] No generations found in history_sql. Creating empty CSV files for common categories and bursts." + ) + # Common (four categories) + _write_one_column_csv( + set(), os.path.join(outdir, "common_base_reactions_always_zero.csv") + ) + _write_one_column_csv( + set(), os.path.join(outdir, "common_base_reactions_always_positive.csv") + ) + _write_one_column_csv( + set(), os.path.join(outdir, "common_base_reactions_always_negative.csv") + ) + _write_one_column_csv( + set(), os.path.join(outdir, "common_base_reactions_oscillating.csv") + ) + # Burst files (three categories) + _write_one_column_csv( + set(), os.path.join(outdir, "burst_base_reaction_always_positive.csv") + ) + _write_one_column_csv( + set(), os.path.join(outdir, "burst_base_reaction_always_negative.csv") + ) + _write_one_column_csv( + set(), os.path.join(outdir, "burst_base_reaction_oscillating.csv") + ) + print(f"[INFO] Empty CSVs written to: {outdir}") + else: + # For each generation, load its flux data and compute categories and zero ratios + reaction_ids = metadata["extended_reaction_names"] + processed_generations = [] + for g in generations: + print(f"\n[INFO] Processing generation: {g}") + gen_df = _load_generation_flux_df(conn, history_sql, reaction_ids, g) + if gen_df.empty: + print( + f"[WARNING] Generation {g} had no data after filtering time==0. Skipping." + ) + continue + + processed_generations.append(g) + + # Get extended flux data (drop time column) + gen_extended_flux_df = gen_df.drop(columns=["time"]) + + # Map extended to base reactions for this generation (mapping function is unchanged) + gen_base_reaction_mapping, gen_extended_to_base_map = ( + map_extended_to_base_reactions( + gen_extended_flux_df.columns.tolist(), + metadata["base_to_extended_mapping"], + ) + ) + + # Compute base fluxes for this generation + gen_base_flux_df, gen_base_reaction_details = ( + compute_base_reaction_fluxes( + gen_extended_flux_df, gen_base_reaction_mapping + ) + ) + + # Compute zero_ratio per base reaction in this generation (abs <= eps) + # number of timepoints for the generation: + n_timepoints_gen = ( + gen_base_flux_df.shape[0] if gen_base_flux_df.shape[0] > 0 else 1 + ) + gen_zero_counts = (np.abs(gen_base_flux_df) <= eps).sum(axis=0) + gen_zero_ratio = gen_zero_counts / n_timepoints_gen # pandas Series + + # Categorize base reactions for this generation using SAME logic (eps retained) + ( + gen_always_positive, + gen_always_negative, + gen_oscillating, + gen_always_zero, + gen_base_reaction_categories, + ) = categorize_base_reactions_by_flux_behavior(gen_base_flux_df, eps) + + # Collect sets of category membership for this generation + per_gen_category_sets[g]["always_zero"] = set(gen_always_zero) + per_gen_category_sets[g]["always_positive"] = set(gen_always_positive) + per_gen_category_sets[g]["always_negative"] = set(gen_always_negative) + per_gen_category_sets[g]["oscillating"] = set(gen_oscillating) + + # For burst detection: within this generation, a reaction is a burst candidate + # if it belongs to the category and its zero_ratio > burst_threshold. + # We record such candidates per generation for each category of interest. + for rxn in gen_always_positive: + # default to 0 if missing + zr = float(gen_zero_ratio.get(rxn, 1.0)) + if zr > burst_threshold: + per_gen_burst_candidate_sets[g]["always_positive"].add(rxn) + for rxn in gen_always_negative: + zr = float(gen_zero_ratio.get(rxn, 1.0)) + if zr > burst_threshold: + per_gen_burst_candidate_sets[g]["always_negative"].add(rxn) + for rxn in gen_oscillating: + zr = float(gen_zero_ratio.get(rxn, 1.0)) + if zr > burst_threshold: + per_gen_burst_candidate_sets[g]["oscillating"].add(rxn) + + print( + f"[INFO] Generation {g} category counts: " + f"always_zero={len(gen_always_zero)}, always_positive={len(gen_always_positive)}, " + f"always_negative={len(gen_always_negative)}, oscillating={len(gen_oscillating)}" + ) + print( + f"[INFO] Generation {g} burst candidate counts (zr > {burst_threshold}): " + f"always_positive={len(per_gen_burst_candidate_sets[g]['always_positive'])}, " + f"always_negative={len(per_gen_burst_candidate_sets[g]['always_negative'])}, " + f"oscillating={len(per_gen_burst_candidate_sets[g]['oscillating'])}" + ) + + # Use only processed_generations (those with data) + if len(processed_generations) == 0: + print( + "[WARNING] No generations had usable data. Writing empty CSVs for common and burst outputs." + ) + # Common (four categories) + _write_one_column_csv( + set(), os.path.join(outdir, "common_base_reactions_always_zero.csv") + ) + _write_one_column_csv( + set(), + os.path.join(outdir, "common_base_reactions_always_positive.csv"), + ) + _write_one_column_csv( + set(), + os.path.join(outdir, "common_base_reactions_always_negative.csv"), + ) + _write_one_column_csv( + set(), os.path.join(outdir, "common_base_reactions_oscillating.csv") + ) + # Burst files (three categories) + _write_one_column_csv( + set(), + os.path.join(outdir, "burst_base_reaction_always_positive.csv"), + ) + _write_one_column_csv( + set(), + os.path.join(outdir, "burst_base_reaction_always_negative.csv"), + ) + _write_one_column_csv( + set(), os.path.join(outdir, "burst_base_reaction_oscillating.csv") + ) + print(f"[INFO] Empty CSVs written to: {outdir}") + else: + # Compute intersection across processed generations for each category (common sets) + common_by_category = {} + for cat in categories: + sets_list = [ + per_gen_category_sets[g][cat] for g in processed_generations + ] + if not sets_list: + common = set() + else: + common = sets_list[0].copy() + for s in sets_list[1:]: + common &= s + common_by_category[cat] = common + print( + f"[INFO] Common across processed generations for '{cat}': {len(common)}" + ) + + # Save the common sets (one-column CSVs) + _write_one_column_csv( + common_by_category.get("always_zero", set()), + os.path.join(outdir, "common_base_reactions_always_zero.csv"), + ) + _write_one_column_csv( + common_by_category.get("always_positive", set()), + os.path.join(outdir, "common_base_reactions_always_positive.csv"), + ) + _write_one_column_csv( + common_by_category.get("always_negative", set()), + os.path.join(outdir, "common_base_reactions_always_negative.csv"), + ) + _write_one_column_csv( + common_by_category.get("oscillating", set()), + os.path.join(outdir, "common_base_reactions_oscillating.csv"), + ) + print( + f"[INFO] Common base reaction CSV files (one per category) saved into: {outdir}" + ) + + # Compute burst intersection across processed generations for each burst category + burst_categories = ["always_positive", "always_negative", "oscillating"] + burst_common_by_category = {} + for bcat in burst_categories: + sets_list = [ + per_gen_burst_candidate_sets[g][bcat] + for g in processed_generations + ] + if not sets_list: + burst_common = set() + else: + burst_common = sets_list[0].copy() + for s in sets_list[1:]: + burst_common &= s + burst_common_by_category[bcat] = burst_common + print( + f"[INFO] Burst-common across processed generations for '{bcat}': {len(burst_common)}" + ) + + # Save burst CSVs (one-column each) + _write_one_column_csv( + burst_common_by_category.get("always_positive", set()), + os.path.join(outdir, "burst_base_reaction_always_positive.csv"), + ) + _write_one_column_csv( + burst_common_by_category.get("always_negative", set()), + os.path.join(outdir, "burst_base_reaction_always_negative.csv"), + ) + _write_one_column_csv( + burst_common_by_category.get("oscillating", set()), + os.path.join(outdir, "burst_base_reaction_oscillating.csv"), + ) + print(f"[INFO] Burst base reaction CSV files saved into: {outdir}") + + print( + "\n[INFO] Base reaction flux preprocessing, visualization, common-category extraction, and burst detection complete." + ) + print(f"[INFO] All files saved to directory: {outdir}") + + # Return same things as original function signature (plus printing above) + return ( + comprehensive_df, + active_base_flux_df, + oscillating_df, + always_zero_df, + base_reaction_details, + base_reaction_mapping, + metadata, + ) + + except Exception as e: + print(f"[ERROR] Analysis failed: {str(e)}") + import traceback + + traceback.print_exc() + return None diff --git a/ecoli/analysis/multigeneration/fba_heat_scatter_combined.py b/ecoli/analysis/multigeneration/fba_heat_scatter_combined.py new file mode 100644 index 000000000..cf9ac9502 --- /dev/null +++ b/ecoli/analysis/multigeneration/fba_heat_scatter_combined.py @@ -0,0 +1,1328 @@ +""" +Base reaction flux analysis script with per-generation common-category extraction and burst detection. + +This script contains the original helper functions for loading FBA data, mapping +extended -> base reactions, computing base reaction net fluxes, categorizing base +reactions by flux behavior (always_positive, always_negative, oscillating, +always_zero), plotting utilities, and saving results. + +Added functionality: +- For each generation present in the history SQL, compute base reaction categories + using the exact same categorization logic. +- Compute the intersection (common base reactions) across all generations for + each category. +- Save those common reaction names into four separate CSV files in outdir: + - common_base_reactions_always_zero.csv + - common_base_reactions_always_positive.csv + - common_base_reactions_always_negative.csv + - common_base_reactions_oscillating.csv +- Compute "burst" base reactions for categories always_positive, always_negative, and oscillating. + A base reaction is considered a burst (for a category) if, for every generation, + when that reaction belongs to that category in that generation, its zero_ratio + (fraction of timepoints with essentially zero flux) is > burst_threshold. + The default burst_threshold is 0.1 (configurable via params["burst_threshold"]). +- Save three CSVs (one-column each) listing burst base reaction names: + - burst_base_reaction_always_positive.csv + - burst_base_reaction_always_negative.csv + - burst_base_reaction_oscillating.csv + +Modified plotting: +- Removed simple scatter plots +- Combined positive and negative scatter plots into unified unidirectional plot + +Important constraints preserved: +- The data loading function `load_fba_data` is left intact (not modified). +- The category determination logic (in `categorize_base_reactions_by_flux_behavior`) + is not altered. +- `plot(...)` remains present and acts as the main entry point. + +Usage: + Call plot(...) with the same parameters expected by the original script. +""" + +import os +from typing import Any +import pandas as pd +import numpy as np +import plotly.graph_objects as go +from plotly.subplots import make_subplots +from scipy.stats import gaussian_kde +from duckdb import DuckDBPyConnection + +# --- BEGIN: Existing helper imports and functions (kept exactly as requested) --- +# Note: These functions are identical to the code you provided. I have not changed +# the internals of load_fba_data, or the categorization logic. +from ecoli.library.parquet_emitter import field_metadata +from ecoli.analysis.utils import create_base_to_extended_mapping + + +def load_fba_data( + conn: DuckDBPyConnection, history_sql: str, config_sql: str, sim_data_dict: dict +) -> tuple[pd.DataFrame, dict]: + """ + Load FBA flux data using DuckDB connection and SQL queries. + + Parameters: + - conn: DuckDB connection + - history_sql: SQL query for historical data + - config_sql: SQL query for configuration data + - sim_data_dict: Dictionary with sim_data information + + Returns: + - df: DataFrame with flux data (time points x extended reactions) + - metadata: Dictionary with experiment metadata + """ + print("[INFO] Loading FBA flux data via SQL...") + + try: + # Create base to extended reaction mapping + base_to_extended_mapping = create_base_to_extended_mapping(sim_data_dict) + if not base_to_extended_mapping: + raise Exception("Could not create base to extended reaction mapping") + + # Load the reaction IDs from the config - this is the array that maps to flux matrix columns + reaction_ids = field_metadata( + conn, config_sql, "listeners__fba_results__reaction_fluxes" + ) + print(f"[INFO] Total reactions in sim_data: {len(reaction_ids)}") + + # Required columns for the query + required_columns = [ + "time", + "generation", + "listeners__fba_results__reaction_fluxes", + ] + + # Build SQL query + sql = f""" + SELECT {", ".join(required_columns)} + FROM ({history_sql}) + ORDER BY generation, time + """ + + # Execute query + df_pl = conn.sql(sql).pl() + + if df_pl.is_empty(): + raise Exception("No data found") + print(f"[INFO] Loaded data with {df_pl.height} time steps") + + # Extract flux matrix and convert to pandas DataFrame + flux_matrix = df_pl["listeners__fba_results__reaction_fluxes"].to_numpy() + flux_matrix = np.array([np.array(row) for row in flux_matrix]) + + # Create DataFrame with extended reactions + df = pd.DataFrame(flux_matrix, columns=reaction_ids) + + # Add time information + time_data = df_pl.select(["time"]).to_pandas() + df = pd.concat([time_data, df], axis=1) + + # Drop initial time point where time == 0 + df = df[df["time"] != 0].reset_index(drop=True) + + # Create metadata dictionary + metadata = { + "n_extended_reactions": len(reaction_ids), + "n_timepoints": len(df), + "extended_reaction_names": reaction_ids, + "base_to_extended_mapping": base_to_extended_mapping, + } + + print("[INFO] Successfully loaded data:") + print(f" - Time points: {len(df)}") + print(f" - Extended reactions: {len(reaction_ids)}") + print(f" - Base reaction mapping entries: {len(base_to_extended_mapping)}") + + return df, metadata + + except Exception as e: + print(f"[ERROR] Failed to load data: {str(e)}") + raise + + +def map_extended_to_base_reactions(extended_reactions, base_to_extended_mapping): + """ + Map extended reactions to base reactions and identify forward/reverse relationships. + + Parameters: + - extended_reactions: List of extended reaction names + - base_to_extended_mapping: Dict mapping base reaction ID to list of extended reaction names + + Returns: + - base_reaction_mapping: Dict with base reaction info including forward/reverse extended reactions + - extended_to_base_map: Dict mapping each extended reaction to its base reaction + """ + # Create reverse mapping from extended to base + extended_to_base_map = {} + for base_rxn, extended_list in base_to_extended_mapping.items(): + for extended_rxn in extended_list: + extended_to_base_map[extended_rxn] = base_rxn + + base_reaction_mapping = {} + + for extended_reaction in extended_reactions: + # Get base reaction name from mapping + base_reaction = extended_to_base_map.get(extended_reaction) + + if base_reaction is None: + print( + f"[WARNING] No base reaction found for extended reaction: {extended_reaction}" + ) + continue + + # Initialize base reaction entry if not exists + if base_reaction not in base_reaction_mapping: + base_reaction_mapping[base_reaction] = { + "forward_extended": [], + "reverse_extended": [], + "all_extended": [], + } + + # Determine if this is a forward or reverse extended reaction + if extended_reaction.endswith(" (reverse)"): + base_reaction_mapping[base_reaction]["reverse_extended"].append( + extended_reaction + ) + else: + base_reaction_mapping[base_reaction]["forward_extended"].append( + extended_reaction + ) + + base_reaction_mapping[base_reaction]["all_extended"].append(extended_reaction) + + print("[INFO] Base reaction mapping results:") + print(f" - Total base reactions: {len(base_reaction_mapping)}") + print(f" - Extended reactions mapped: {len(extended_to_base_map)}") + + # Print statistics about forward/reverse distributions + forward_only = sum( + 1 + for info in base_reaction_mapping.values() + if len(info["forward_extended"]) > 0 and len(info["reverse_extended"]) == 0 + ) + reverse_only = sum( + 1 + for info in base_reaction_mapping.values() + if len(info["forward_extended"]) == 0 and len(info["reverse_extended"]) > 0 + ) + both_directions = sum( + 1 + for info in base_reaction_mapping.values() + if len(info["forward_extended"]) > 0 and len(info["reverse_extended"]) > 0 + ) + + print(f" - Base reactions with forward extended only: {forward_only}") + print(f" - Base reactions with reverse extended only: {reverse_only}") + print( + f" - Base reactions with both forward and reverse extended: {both_directions}" + ) + + return base_reaction_mapping, extended_to_base_map + + +def compute_base_reaction_fluxes(flux_df, base_reaction_mapping): + """ + Compute base reaction fluxes by summing forward extended and subtracting reverse extended fluxes. + + Parameters: + - flux_df: DataFrame with extended reaction flux values + - base_reaction_mapping: Dict with base reaction info + + Returns: + - base_flux_df: DataFrame with base reaction net flux values + - base_reaction_details: Dict with detailed info about each base reaction + """ + base_flux_data = {} + base_reaction_details = {} + + for base_reaction, info in base_reaction_mapping.items(): + forward_extended = info["forward_extended"] + reverse_extended = info["reverse_extended"] + + # Sum forward extended fluxes + forward_flux = pd.Series(0.0, index=flux_df.index) + if forward_extended: + for ext_reaction in forward_extended: + if ext_reaction in flux_df.columns: + forward_flux += flux_df[ext_reaction] + + # Sum reverse extended fluxes + reverse_flux = pd.Series(0.0, index=flux_df.index) + if reverse_extended: + for ext_reaction in reverse_extended: + if ext_reaction in flux_df.columns: + reverse_flux += flux_df[ext_reaction] + + # Net flux = forward - reverse + net_flux = forward_flux - reverse_flux + base_flux_data[base_reaction] = net_flux + + # Store details for analysis + base_reaction_details[base_reaction] = { + "forward_extended": forward_extended, + "reverse_extended": reverse_extended, + "n_forward_extended": len(forward_extended), + "n_reverse_extended": len(reverse_extended), + "total_extended": len(info["all_extended"]), + } + + base_flux_df = pd.DataFrame(base_flux_data) + + print("[INFO] Base reaction flux computation results:") + print(f" - Base reactions computed: {len(base_flux_df.columns)}") + print(f" - Time points: {len(base_flux_df)}") + + return base_flux_df, base_reaction_details + + +def categorize_base_reactions_by_flux_behavior(base_flux_df, eps=1e-30): + """ + Categorize base reactions based on their flux behavior across time steps. + + Parameters: + - base_flux_df: DataFrame with base reaction net flux values + - eps: Small tolerance for zero comparison + + Returns: + - always_positive: List of base reactions that are always >= 0 and have max > 0 + - always_negative: List of base reactions that are always <= 0 and have max abs > 0 + - oscillating: List of base reactions that change sign + - always_zero: List of base reactions that are always zero + - base_reaction_categories: Dictionary with detailed categorization info + """ + always_positive = [] + always_negative = [] + oscillating = [] + always_zero = [] + base_reaction_categories = {} + + for base_reaction in base_flux_df.columns: + flux_values = base_flux_df[base_reaction].values + + # Check for positive, negative, and zero values + has_positive = np.any(flux_values > eps) + has_negative = np.any(flux_values < -eps) + has_zero = np.any(np.abs(flux_values) <= eps) + + min_flux = flux_values.min() + max_flux = flux_values.max() + max_abs_flux = np.max(np.abs(flux_values)) + + # Categorize based on behavior + if max_abs_flux <= eps: # All values are essentially zero + always_zero.append(base_reaction) + category = "always_zero" + elif not has_negative and has_positive: # All values >= -eps and has some > eps + always_positive.append(base_reaction) + category = "always_positive" + elif not has_positive and has_negative: # All values <= eps and has some < -eps + always_negative.append(base_reaction) + category = "always_negative" + elif has_positive and has_negative: # Has both positive and negative values + oscillating.append(base_reaction) + category = "oscillating" + else: + # This case should be covered by always_zero, but keep as safety net + always_zero.append(base_reaction) + category = "always_zero" + + base_reaction_categories[base_reaction] = { + "category": category, + "min_flux": min_flux, + "max_flux": max_flux, + "max_abs_flux": max_abs_flux, + "has_positive": has_positive, + "has_negative": has_negative, + "has_zero": has_zero, + } + + print("\n[INFO] Base reaction categorization by flux behavior:") + print(f" - Always positive (>= 0, max > 0): {len(always_positive)}") + print(f" - Always negative (<= 0, max abs > 0): {len(always_negative)}") + print(f" - Oscillating (changes sign): {len(oscillating)}") + print(f" - Always zero (max abs ≈ 0): {len(always_zero)}") + + return ( + always_positive, + always_negative, + oscillating, + always_zero, + base_reaction_categories, + ) + + +def print_base_reaction_category_summaries( + positive_df, negative_df, oscillating_df, always_zero_df +): + """Print summary information for each base reaction category.""" + + if len(positive_df) > 0: + print( + "\n[INFO] Always Positive Base Reactions (max flux > 0) - Top 5 most active (lowest zero ratio):" + ) + top_positive = positive_df.sort_values("zero_ratio").head(5) + for _, row in top_positive.iterrows(): + print( + f" {row['base_reaction']}: zero_ratio={row['zero_ratio']:.4f}, max_flux={row['max_flux']:.2e}, ext_reactions={row['total_extended']}" + ) + + if len(negative_df) > 0: + print( + "\n[INFO] Always Negative Base Reactions (max abs flux > 0) - Top 5 most active (lowest zero ratio):" + ) + top_negative = negative_df.sort_values("zero_ratio").head(5) + for _, row in top_negative.iterrows(): + print( + f" {row['base_reaction']}: zero_ratio={row['zero_ratio']:.4f}, max_abs_flux={row['max_abs_flux']:.2e}, ext_reactions={row['total_extended']}" + ) + + if len(oscillating_df) > 0: + print( + "\n[INFO] Oscillating Base Reactions - Top 5 most active (lowest zero ratio):" + ) + top_oscillating = oscillating_df.sort_values("zero_ratio").head(5) + for _, row in top_oscillating.iterrows(): + print( + f" {row['base_reaction']}: zero_ratio={row['zero_ratio']:.4f}, max_abs_flux={row['max_abs_flux']:.2e}, ext_reactions={row['total_extended']}" + ) + + if len(always_zero_df) > 0: + print("\n[INFO] Always Zero Base Reactions - First 5 examples:") + first_zero = always_zero_df.head(5) + for _, row in first_zero.iterrows(): + print( + f" {row['base_reaction']}: max_abs_flux={row['max_abs_flux']:.2e}, ext_reactions={row['total_extended']}" + ) + + +def create_unified_directional_flux_plot( + positive_df, negative_df, epsilon_log, base_reaction_details, outdir +): + """Create unified heat scatter plot combining always positive and always negative base reactions.""" + + # Check if we have data to plot + if len(positive_df) == 0 and len(negative_df) == 0: + print( + "[WARNING] No positive or negative base reactions found. Skipping unified plot." + ) + return + + # Combine the dataframes + combined_data = [] + + # Add positive reactions + for idx, row in positive_df.iterrows(): + combined_data.append( + { + "base_reaction": row["base_reaction"], + "category": "Always Positive", + "zero_ratio": row["zero_ratio"], + "max_abs_flux": row["max_abs_flux"], + "log_max_abs_flux": row["log_max_abs_flux"], + "min_flux": row["min_flux"], + "max_flux": row["max_flux"], + "total_extended": row["total_extended"], + } + ) + + # Add negative reactions + for idx, row in negative_df.iterrows(): + combined_data.append( + { + "base_reaction": row["base_reaction"], + "category": "Always Negative", + "zero_ratio": row["zero_ratio"], + "max_abs_flux": row["max_abs_flux"], + "log_max_abs_flux": row["log_max_abs_flux"], + "min_flux": row["min_flux"], + "max_flux": row["max_flux"], + "total_extended": row["total_extended"], + } + ) + + if not combined_data: + print("[WARNING] No combined data available for unified plot.") + return + + combined_df = pd.DataFrame(combined_data) + + # Prepare data for plotting + x = combined_df["log_max_abs_flux"] + y = combined_df["zero_ratio"] + categories = combined_df["category"] + + # Calculate point density using gaussian_kde for heat scatter plot + if len(combined_df) > 1: # Need at least 2 points for KDE + xy = np.vstack([x, y]) + density = gaussian_kde(xy)(xy) + else: + density = np.array([1.0]) # Single point gets density of 1 + + # Create hover text with base reaction information + hover_text = [] + for idx, row in combined_df.iterrows(): + base_reaction = row["base_reaction"] + details = base_reaction_details.get(base_reaction, {}) + + # Create extended reaction info + forward_ext = details.get("forward_extended", []) + reverse_ext = details.get("reverse_extended", []) + + ext_info = f"Forward: {len(forward_ext)} extended, Reverse: {len(reverse_ext)} extended" + if len(forward_ext) <= 3: + ext_info += ( + f"
Forward: {', '.join(forward_ext) if forward_ext else 'None'}" + ) + if len(reverse_ext) <= 3: + ext_info += ( + f"
Reverse: {', '.join(reverse_ext) if reverse_ext else 'None'}" + ) + + hover_text.append( + f"Base Reaction: {base_reaction}
" + + f"Extended Reactions: {ext_info}
" + + f"Category: {row['category']}
" + + f"Zero Ratio: {row['zero_ratio']:.4f}
" + + f"Max |Net Flux|: {row['max_abs_flux']:.2e}
" + + f"Min Net Flux: {row['min_flux']:.2e}
" + + f"Max Net Flux: {row['max_flux']:.2e}
" + + f"Log |Max Net Flux|: {row['log_max_abs_flux']:.2f}
" + + f"Point Density: {density[idx]:.6f}" + ) + + # Create subplot with marginal histograms + fig_heat = make_subplots( + rows=2, + cols=2, + column_widths=[0.9, 0.1], + row_heights=[0.1, 0.9], + specs=[ + [{"secondary_y": False}, {"secondary_y": False}], + [{"secondary_y": False}, {"secondary_y": False}], + ], + vertical_spacing=0.05, + horizontal_spacing=0.05, + subplot_titles=("", "", "", ""), + ) + + # Define colors for categories + color_map = { + "Always Positive": "rgba(70, 130, 180, 0.8)", # Steel blue + "Always Negative": "rgba(220, 20, 60, 0.8)", # Crimson + } + + # Add scatter traces for each category + for category in ["Always Positive", "Always Negative"]: + mask = categories == category + if not mask.any(): + continue + + category_x = x[mask] + category_y = y[mask] + category_density = density[mask] + category_hover = [hover_text[i] for i in range(len(hover_text)) if mask.iloc[i]] + + fig_heat.add_trace( + go.Scatter( + x=category_x, + y=category_y, + mode="markers", + marker=dict( + size=8, + color=category_density, + colorscale="Plasma", + opacity=0.8, + line=dict( + width=0.5, color=color_map[category].replace("0.8", "1.0") + ), + ), + text=category_hover, + hovertemplate="%{text}", + name=category, + showlegend=True, + ), + row=2, + col=1, + ) + + # Add colorbar for density + fig_heat.data[0].marker.colorbar = dict( + title=dict(text="Point Density", font=dict(size=14)), + tickfont=dict(size=12), + thickness=15, + len=0.7, + x=1.02, # Position colorbar to the right + ) + + # Top density curve (x-axis distribution, row=1, col=1) + if len(x) > 1: + # Create smooth density curve for x-axis + x_range = np.linspace(x.min(), x.max(), 250) + x_density = gaussian_kde(x) + x_density_values = x_density(x_range) + else: + # Single point case + x_range = np.array([x.iloc[0]]) + x_density_values = np.array([1.0]) + + fig_heat.add_trace( + go.Scatter( + x=x_range, + y=x_density_values, + mode="lines", + line=dict(color="steelblue", width=3), + fill="tozeroy", + fillcolor="rgba(70, 130, 180, 0.3)", + name="X Density", + showlegend=False, + ), + row=1, + col=1, + ) + + # Right density curve (y-axis distribution, row=2, col=2) + if len(y) > 1: + # Create smooth density curve for y-axis + y_range = np.linspace(y.min(), y.max(), 250) + y_density = gaussian_kde(y) + y_density_values = y_density(y_range) + else: + # Single point case + y_range = np.array([y.iloc[0]]) + y_density_values = np.array([1.0]) + + fig_heat.add_trace( + go.Scatter( + x=y_density_values, + y=y_range, + mode="lines", + line=dict(color="lightcoral", width=3), + fill="tozerox", + fillcolor="rgba(240, 128, 128, 0.3)", + name="Y Density", + showlegend=False, + ), + row=2, + col=2, + ) + + # Update layout for heat scatter with histograms + fig_heat.update_layout( + title=dict( + text="Unified Directional Heat Scatter Plot: Always Positive and Always Negative Base Reaction Net Flux", + font=dict(size=18), + x=0.5, + xanchor="center", + ), + plot_bgcolor="white", + paper_bgcolor="white", + font=dict(family="Arial", size=12), + width=1000, + height=800, + margin=dict(l=80, r=120, t=100, b=80), + ) + + # Update axes for main plot + fig_heat.update_xaxes( + title=dict(text="log₁₀(ε + |Max Net Flux|)", font=dict(size=14)), + tickfont=dict(size=12), + gridcolor="rgba(128,128,128,0.2)", + gridwidth=1, + row=2, + col=1, + ) + fig_heat.update_yaxes( + title=dict(text="Zero Flux Ratio", font=dict(size=14)), + tickfont=dict(size=12), + gridcolor="rgba(128,128,128,0.2)", + gridwidth=1, + row=2, + col=1, + ) + + # Update axes for histograms (remove tick labels and titles) + fig_heat.update_xaxes(showticklabels=False, title="", row=1, col=1) + fig_heat.update_yaxes(showticklabels=False, title="", row=1, col=1) + fig_heat.update_xaxes(showticklabels=False, title="", row=2, col=2) + fig_heat.update_yaxes(showticklabels=False, title="", row=2, col=2) + + # Hide the top-right subplot + fig_heat.update_xaxes(visible=False, row=1, col=2) + fig_heat.update_yaxes(visible=False, row=1, col=2) + + # Add statistics annotation + stats_text = ( + f"Base Reactions (Always Positive): {len(positive_df):,}
" + + f"Base Reactions (Always Negative): {len(negative_df):,}
" + + f"Total Base Reactions: {len(combined_df):,}
" + + f"ε = {epsilon_log:.0e}
" + + f"|Max Net Flux| Range: {combined_df['max_abs_flux'].min():.2e} to {combined_df['max_abs_flux'].max():.2e}
" + + f"Zero Ratio Range: {combined_df['zero_ratio'].min():.4f} to {combined_df['zero_ratio'].max():.4f}
" + + f"Extended Reactions Range: {combined_df['total_extended'].min()} to {combined_df['total_extended'].max()}" + ) + + if len(combined_df) > 1: + stats_text += f"
Density Range: {density.min():.2e} - {density.max():.2e}" + + fig_heat.add_annotation( + x=0.02, + y=0.48, + xref="paper", + yref="paper", + text=stats_text, + showarrow=False, + font=dict(size=11, color="black"), + bgcolor="rgba(255,255,255,0.9)", + bordercolor="rgba(128,128,128,0.5)", + borderwidth=1, + borderpad=10, + xanchor="left", + yanchor="top", + ) + + # Save plot + filename = os.path.join( + outdir, + "unified_directional_heat_scatter_base_reactions.html", + ) + + fig_heat.write_html(filename) + + print("\n[INFO] Unified directional base reaction flux plot saved:") + print(f" - {filename}") + + +def save_base_reaction_results( + comprehensive_df, + oscillating_df, + always_zero_df, + base_reaction_details, + base_reaction_mapping, + active_base_flux_df, + metadata, + outdir, +): + """Save all base reaction results to CSV files with experiment metadata in outdir.""" + + # Create output directory if it doesn't exist + os.makedirs(outdir, exist_ok=True) + + # Create filename prefix + prefix = "base_reaction_analysis" + + # Save comprehensive base reaction metrics + comprehensive_filename = os.path.join(outdir, f"{prefix}_metrics.csv") + comprehensive_df.to_csv(comprehensive_filename, index=False) + print( + f"\n[INFO] Comprehensive base reaction metrics saved to '{comprehensive_filename}'" + ) + + # Save oscillating base reactions specifically + if len(oscillating_df) > 0: + oscillating_filename = os.path.join(outdir, f"{prefix}_oscillating.csv") + oscillating_df.to_csv(oscillating_filename, index=False) + print(f"[INFO] Oscillating base reactions saved to '{oscillating_filename}'") + else: + print("[INFO] No oscillating base reactions found.") + + # Save always zero base reactions specifically + if len(always_zero_df) > 0: + zero_filename = os.path.join(outdir, f"{prefix}_always_zero.csv") + always_zero_df.to_csv(zero_filename, index=False) + print(f"[INFO] Always zero base reactions saved to '{zero_filename}'") + else: + print("[INFO] No always zero base reactions found.") + + # Save base reaction mapping details + mapping_data = [] + for base_reaction, info in base_reaction_mapping.items(): + mapping_data.append( + { + "base_reaction": base_reaction, + "forward_extended": "; ".join(info["forward_extended"]), + "reverse_extended": "; ".join(info["reverse_extended"]), + "n_forward_extended": len(info["forward_extended"]), + "n_reverse_extended": len(info["reverse_extended"]), + "total_extended": len(info["all_extended"]), + } + ) + + mapping_df = pd.DataFrame(mapping_data) + mapping_filename = os.path.join(outdir, f"{prefix}_extended_mapping.csv") + mapping_df.to_csv(mapping_filename, index=False) + print( + f"[INFO] Base reaction to extended reaction mapping saved to '{mapping_filename}'" + ) + + # Save filtered active base reaction flux data + flux_filename = os.path.join(outdir, f"{prefix}_filtered_flux.csv") + active_base_flux_df.to_csv(flux_filename, index=False, encoding="utf-8-sig") + print(f"[INFO] Filtered active base reactions saved to '{flux_filename}'") + + # Save metadata + metadata_filename = os.path.join(outdir, f"{prefix}_metadata.csv") + metadata_for_csv = { + k: v for k, v in metadata.items() if not isinstance(v, (dict, list, np.ndarray)) + } # Only save simple types + metadata_for_csv["n_base_reactions"] = len(comprehensive_df) + metadata_df = pd.DataFrame([metadata_for_csv]) + metadata_df.to_csv(metadata_filename, index=False) + print(f"[INFO] Experiment metadata saved to '{metadata_filename}'") + + # Print detailed summary statistics by category + print("\n[INFO] Detailed Summary Statistics by Category for Base Reactions:") + + for category in [ + "always_positive", + "always_negative", + "oscillating", + "always_zero", + ]: + cat_df = comprehensive_df[comprehensive_df["category"] == category] + if len(cat_df) > 0: + print( + f"\n {category.replace('_', ' ').title()} Base Reactions ({len(cat_df)}):" + ) + print( + f" Zero ratio range: {cat_df['zero_ratio'].min():.4f} - {cat_df['zero_ratio'].max():.4f}" + ) + print( + f" |Max net flux| range: {cat_df['max_abs_flux'].min():.2e} - {cat_df['max_abs_flux'].max():.2e}" + ) + print( + f" Min net flux range: {cat_df['min_flux'].min():.2e} - {cat_df['min_flux'].max():.2e}" + ) + print( + f" Max net flux range: {cat_df['max_flux'].min():.2e} - {cat_df['max_flux'].max():.2e}" + ) + print( + f" Extended reactions per base: {cat_df['total_extended'].min()} - {cat_df['total_extended'].max()}" + ) + + # Count reactions with different flux behaviors + has_zero_count = cat_df["has_zero"].sum() + print(f" Base reactions with zero flux points: {has_zero_count}") + + # Extended reaction statistics + total_forward_ext = cat_df["n_forward_extended"].sum() + total_reverse_ext = cat_df["n_reverse_extended"].sum() + print(f" Total forward extended reactions: {total_forward_ext}") + print(f" Total reverse extended reactions: {total_reverse_ext}") + + print(f"\nTotal base reactions: {len(comprehensive_df)}") + print(f"Total extended reactions mapped: {metadata['n_extended_reactions']}") + + +# --- END: Existing helper functions --- + +# --- BEGIN: New helpers & updated plot() that compute common base reaction names across generations and burst detection --- + + +def _get_distinct_generations(conn: DuckDBPyConnection, history_sql: str) -> list: + """ + Return a sorted list of distinct generation values from the history SQL. + + We treat history_sql as a subquery (it may already be a SELECT ...). + """ + gen_query = f"SELECT DISTINCT generation FROM ({history_sql}) ORDER BY generation" + gen_pl = conn.sql(gen_query).pl() + if gen_pl.is_empty(): + return [] + gen_df = gen_pl.to_pandas() + # Expect a column named 'generation' + generations = gen_df["generation"].tolist() + return generations + + +def _load_generation_flux_df( + conn: DuckDBPyConnection, history_sql: str, reaction_ids: list, generation: int +) -> pd.DataFrame: + """ + Load flux dataframe for a specific generation. + + Returns DataFrame where columns are: 'time' + reaction_ids (extended reaction columns). + Drops time==0 rows to match the main loader behavior. + """ + # Build SQL for this generation + required_columns = [ + "time", + "generation", + "listeners__fba_results__reaction_fluxes", + ] + sql = f""" + SELECT {", ".join(required_columns)} + FROM ({history_sql}) + WHERE generation = {generation} + ORDER BY time + """ + df_pl = conn.sql(sql).pl() + if df_pl.is_empty(): + return pd.DataFrame() # empty + + flux_matrix = df_pl["listeners__fba_results__reaction_fluxes"].to_numpy() + flux_matrix = np.array([np.array(row) for row in flux_matrix]) + df_ext = pd.DataFrame(flux_matrix, columns=reaction_ids) + + time_data = df_pl.select(["time"]).to_pandas() + df_full = pd.concat([time_data, df_ext], axis=1) + # Drop initial time point where time == 0 to remain consistent + df_full = df_full[df_full["time"] != 0].reset_index(drop=True) + return df_full + + +def plot( + params: dict[str, Any], + conn: DuckDBPyConnection, + history_sql: str, + config_sql: str, + success_sql: str, + sim_data_dict: dict[str, dict[int, str]], + validation_data_paths: list[str], + outdir: str, + variant_metadata: dict[str, dict[int, Any]], + variant_names: dict[str, str], +): + """ + Preprocesses FBA flux data by mapping extended reactions to base reactions, + computes net fluxes for base reactions (forward extended - reverse extended), + categorizes base reactions based on flux behavior, creates unified visualizations, + and additionally computes base reactions that are common across all generations + for each category (always_zero, always_positive, always_negative, oscillating). + + Also computes "burst" base reactions for categories always_positive, + always_negative, and oscillating: a base reaction is considered a burst if + in every generation the reaction (when in that category) has zero_ratio > + burst_threshold. + + Modified plotting: + - Removed simple scatter plots + - Combined positive and negative scatter plots into unified unidirectional plot + + Saves: + - Four CSVs with common reactions per category (one-column: base_reaction_name) + - Three CSVs with burst reactions per category (one-column: base_reaction_name) + + Returns the same tuple as the earlier design plus printing results. + """ + + # Get parameters with defaults + zero_threshold = params.get("zero_threshold", 0.999) + eps = params.get("eps", 1e-30) + epsilon_log = params.get("epsilon_log", 1e-30) + burst_threshold = params.get("burst_threshold", 0.1) + + print("[INFO] Starting base reaction flux analysis...") + print( + f"[INFO] Parameters: zero_threshold={zero_threshold}, eps={eps}, epsilon_log={epsilon_log}, burst_threshold={burst_threshold}" + ) + print(f"[INFO] Output directory: {outdir}") + + try: + # Load data (this uses the original loader - we DO NOT modify it) + df, metadata = load_fba_data(conn, history_sql, config_sql, sim_data_dict) + + # Separate flux data (drop time columns) + extended_flux_df = df.drop(columns=["time"]) + + print( + f"[INFO] Original data: {len(extended_flux_df.columns)} extended reactions, {len(extended_flux_df)} time points" + ) + + # Map extended reactions to base reactions using the full set + base_reaction_mapping, extended_to_base_map = map_extended_to_base_reactions( + extended_flux_df.columns.tolist(), metadata["base_to_extended_mapping"] + ) + + # Compute base reaction fluxes for the entire dataset + base_flux_df, base_reaction_details = compute_base_reaction_fluxes( + extended_flux_df, base_reaction_mapping + ) + + # Categorize base reactions by flux behavior for the entire dataset (keeps behavior logic unchanged) + ( + always_positive, + always_negative, + oscillating, + always_zero, + base_reaction_categories, + ) = categorize_base_reactions_by_flux_behavior(base_flux_df, eps) + + # Compute zero-flux ratio per base reaction (using absolute values) + zero_counts = (np.abs(base_flux_df) <= eps).sum(axis=0) + zero_ratio = zero_counts / base_flux_df.shape[0] + + # Create comprehensive dataframe with all metrics + comprehensive_data = [] + for base_reaction in base_flux_df.columns: + cat_info = base_reaction_categories[base_reaction] + details = base_reaction_details[base_reaction] + comprehensive_data.append( + { + "base_reaction": base_reaction, + "category": cat_info["category"], + "zero_ratio": zero_ratio[base_reaction], + "min_flux": cat_info["min_flux"], + "max_flux": cat_info["max_flux"], + "max_abs_flux": cat_info["max_abs_flux"], + "log_max_abs_flux": np.log10( + epsilon_log + cat_info["max_abs_flux"] + ), + "has_positive": cat_info["has_positive"], + "has_negative": cat_info["has_negative"], + "has_zero": cat_info["has_zero"], + "n_forward_extended": details["n_forward_extended"], + "n_reverse_extended": details["n_reverse_extended"], + "total_extended": details["total_extended"], + } + ) + + comprehensive_df = pd.DataFrame(comprehensive_data) + + # Filter out base reactions with high zero_ratio + active_base_reactions = comprehensive_df[ + comprehensive_df["zero_ratio"] < zero_threshold + ]["base_reaction"].tolist() + active_base_flux_df = base_flux_df[active_base_reactions] + print( + f"\n[INFO] Base reactions remaining after filtering (zero_ratio < {zero_threshold}): {len(active_base_reactions)}" + ) + + # Create separate datasets for visualization + positive_df = comprehensive_df[ + (comprehensive_df["category"] == "always_positive") + & (comprehensive_df["max_flux"] > 0) + ].copy() + + negative_df = comprehensive_df[ + (comprehensive_df["category"] == "always_negative") + & (comprehensive_df["max_abs_flux"] > 0) + ].copy() + + oscillating_df = comprehensive_df[ + comprehensive_df["category"] == "oscillating" + ].copy() + always_zero_df = comprehensive_df[ + comprehensive_df["category"] == "always_zero" + ].copy() + + # Print filtering results + print("\n[INFO] Filtered datasets for visualization:") + print(f" - Always positive with max flux > 0: {len(positive_df)}") + print(f" - Always negative with max abs flux > 0: {len(negative_df)}") + print(f" - Oscillating: {len(oscillating_df)}") + print(f" - Always zero: {len(always_zero_df)}") + + # Print top/bottom reactions for each category + print_base_reaction_category_summaries( + positive_df, negative_df, oscillating_df, always_zero_df + ) + + # Create unified directional visualization (combines positive and negative) + create_unified_directional_flux_plot( + positive_df, negative_df, epsilon_log, base_reaction_details, outdir + ) + + # Save the standard outputs (comprehensive and per-category CSVs, mapping, fluxes, metadata) + save_base_reaction_results( + comprehensive_df, + oscillating_df, + always_zero_df, + base_reaction_details, + base_reaction_mapping, + active_base_flux_df, + metadata, + outdir, + ) + + # + # NEW: Per-generation category extraction, intersections across generations, + # and burst detection; save results into separate single-column CSV files. + # + print( + "\n[INFO] Computing per-generation categories, intersections across generations, and burst detection..." + ) + + # 1) get list of distinct generations from the history_sql + generations = _get_distinct_generations(conn, history_sql) + print(f"[INFO] Found generations: {generations}") + + # Prepare a dict to collect sets per generation for each category + per_gen_category_sets = {} + # categories we care about + categories = [ + "always_zero", + "always_positive", + "always_negative", + "oscillating", + ] + for g in generations: + per_gen_category_sets[g] = {cat: set() for cat in categories} + + # Also prepare dicts to collect per-generation sets of reactions that satisfy zero_ratio > burst_threshold + per_gen_burst_candidate_sets = {} + for g in generations: + per_gen_burst_candidate_sets[g] = { + "always_positive": set(), + "always_negative": set(), + "oscillating": set(), + } + + # Make sure outdir exists + os.makedirs(outdir, exist_ok=True) + + # Helper to write a set of base names to a single-column CSV with header 'base_reaction_name' + def _write_one_column_csv(base_names_set: set, filepath: str): + if not base_names_set: + # create empty dataframe with correct column + pd.DataFrame(columns=["base_reaction_name"]).to_csv( + filepath, index=False + ) + else: + df_out = pd.DataFrame( + sorted(list(base_names_set)), columns=["base_reaction_name"] + ) + df_out.to_csv(filepath, index=False) + + # If no generation information found, still create the common and burst CSVs empty + if len(generations) == 0: + print( + "[WARNING] No generations found in history_sql. Creating empty CSV files for common categories and bursts." + ) + # Common (four categories) + _write_one_column_csv( + set(), os.path.join(outdir, "common_base_reactions_always_zero.csv") + ) + _write_one_column_csv( + set(), os.path.join(outdir, "common_base_reactions_always_positive.csv") + ) + _write_one_column_csv( + set(), os.path.join(outdir, "common_base_reactions_always_negative.csv") + ) + _write_one_column_csv( + set(), os.path.join(outdir, "common_base_reactions_oscillating.csv") + ) + # Burst files (three categories) + _write_one_column_csv( + set(), os.path.join(outdir, "burst_base_reaction_always_positive.csv") + ) + _write_one_column_csv( + set(), os.path.join(outdir, "burst_base_reaction_always_negative.csv") + ) + _write_one_column_csv( + set(), os.path.join(outdir, "burst_base_reaction_oscillating.csv") + ) + print(f"[INFO] Empty CSVs written to: {outdir}") + else: + # For each generation, load its flux data and compute categories and zero ratios + reaction_ids = metadata["extended_reaction_names"] + processed_generations = [] + for g in generations: + print(f"\n[INFO] Processing generation: {g}") + gen_df = _load_generation_flux_df(conn, history_sql, reaction_ids, g) + if gen_df.empty: + print( + f"[WARNING] Generation {g} had no data after filtering time==0. Skipping." + ) + continue + + processed_generations.append(g) + + # Get extended flux data (drop time column) + gen_extended_flux_df = gen_df.drop(columns=["time"]) + + # Map extended to base reactions for this generation (mapping function is unchanged) + gen_base_reaction_mapping, gen_extended_to_base_map = ( + map_extended_to_base_reactions( + gen_extended_flux_df.columns.tolist(), + metadata["base_to_extended_mapping"], + ) + ) + + # Compute base fluxes for this generation + gen_base_flux_df, gen_base_reaction_details = ( + compute_base_reaction_fluxes( + gen_extended_flux_df, gen_base_reaction_mapping + ) + ) + + # Compute zero_ratio per base reaction in this generation (abs <= eps) + # number of timepoints for the generation: + n_timepoints_gen = ( + gen_base_flux_df.shape[0] if gen_base_flux_df.shape[0] > 0 else 1 + ) + gen_zero_counts = (np.abs(gen_base_flux_df) <= eps).sum(axis=0) + gen_zero_ratio = gen_zero_counts / n_timepoints_gen # pandas Series + + # Categorize base reactions for this generation using SAME logic (eps retained) + ( + gen_always_positive, + gen_always_negative, + gen_oscillating, + gen_always_zero, + gen_base_reaction_categories, + ) = categorize_base_reactions_by_flux_behavior(gen_base_flux_df, eps) + + # Collect sets of category membership for this generation + per_gen_category_sets[g]["always_zero"] = set(gen_always_zero) + per_gen_category_sets[g]["always_positive"] = set(gen_always_positive) + per_gen_category_sets[g]["always_negative"] = set(gen_always_negative) + per_gen_category_sets[g]["oscillating"] = set(gen_oscillating) + + # For burst detection: within this generation, a reaction is a burst candidate + # if it belongs to the category and its zero_ratio > burst_threshold. + # We record such candidates per generation for each category of interest. + for rxn in gen_always_positive: + # default to 0 if missing + zr = float(gen_zero_ratio.get(rxn, 1.0)) + if zr > burst_threshold: + per_gen_burst_candidate_sets[g]["always_positive"].add(rxn) + for rxn in gen_always_negative: + zr = float(gen_zero_ratio.get(rxn, 1.0)) + if zr > burst_threshold: + per_gen_burst_candidate_sets[g]["always_negative"].add(rxn) + for rxn in gen_oscillating: + zr = float(gen_zero_ratio.get(rxn, 1.0)) + if zr > burst_threshold: + per_gen_burst_candidate_sets[g]["oscillating"].add(rxn) + + print( + f"[INFO] Generation {g} category counts: " + f"always_zero={len(gen_always_zero)}, always_positive={len(gen_always_positive)}, " + f"always_negative={len(gen_always_negative)}, oscillating={len(gen_oscillating)}" + ) + print( + f"[INFO] Generation {g} burst candidate counts (zr > {burst_threshold}): " + f"always_positive={len(per_gen_burst_candidate_sets[g]['always_positive'])}, " + f"always_negative={len(per_gen_burst_candidate_sets[g]['always_negative'])}, " + f"oscillating={len(per_gen_burst_candidate_sets[g]['oscillating'])}" + ) + + # Use only processed_generations (those with data) + if len(processed_generations) == 0: + print( + "[WARNING] No generations had usable data. Writing empty CSVs for common and burst outputs." + ) + # Common (four categories) + _write_one_column_csv( + set(), os.path.join(outdir, "common_base_reactions_always_zero.csv") + ) + _write_one_column_csv( + set(), + os.path.join(outdir, "common_base_reactions_always_positive.csv"), + ) + _write_one_column_csv( + set(), + os.path.join(outdir, "common_base_reactions_always_negative.csv"), + ) + _write_one_column_csv( + set(), os.path.join(outdir, "common_base_reactions_oscillating.csv") + ) + # Burst files (three categories) + _write_one_column_csv( + set(), + os.path.join(outdir, "burst_base_reaction_always_positive.csv"), + ) + _write_one_column_csv( + set(), + os.path.join(outdir, "burst_base_reaction_always_negative.csv"), + ) + _write_one_column_csv( + set(), os.path.join(outdir, "burst_base_reaction_oscillating.csv") + ) + print(f"[INFO] Empty CSVs written to: {outdir}") + else: + # Compute intersection across processed generations for each category (common sets) + common_by_category = {} + for cat in categories: + sets_list = [ + per_gen_category_sets[g][cat] for g in processed_generations + ] + if not sets_list: + common = set() + else: + common = sets_list[0].copy() + for s in sets_list[1:]: + common &= s + common_by_category[cat] = common + print( + f"[INFO] Common across processed generations for '{cat}': {len(common)}" + ) + + # Save the common sets (one-column CSVs) + _write_one_column_csv( + common_by_category.get("always_zero", set()), + os.path.join(outdir, "common_base_reactions_always_zero.csv"), + ) + _write_one_column_csv( + common_by_category.get("always_positive", set()), + os.path.join(outdir, "common_base_reactions_always_positive.csv"), + ) + _write_one_column_csv( + common_by_category.get("always_negative", set()), + os.path.join(outdir, "common_base_reactions_always_negative.csv"), + ) + _write_one_column_csv( + common_by_category.get("oscillating", set()), + os.path.join(outdir, "common_base_reactions_oscillating.csv"), + ) + print( + f"[INFO] Common base reaction CSV files (one per category) saved into: {outdir}" + ) + + # Compute burst intersection across processed generations for each burst category + burst_categories = ["always_positive", "always_negative", "oscillating"] + burst_common_by_category = {} + for bcat in burst_categories: + sets_list = [ + per_gen_burst_candidate_sets[g][bcat] + for g in processed_generations + ] + if not sets_list: + burst_common = set() + else: + burst_common = sets_list[0].copy() + for s in sets_list[1:]: + burst_common &= s + burst_common_by_category[bcat] = burst_common + print( + f"[INFO] Burst-common across processed generations for '{bcat}': {len(burst_common)}" + ) + + # Save burst CSVs (one-column each) + _write_one_column_csv( + burst_common_by_category.get("always_positive", set()), + os.path.join(outdir, "burst_base_reaction_always_positive.csv"), + ) + _write_one_column_csv( + burst_common_by_category.get("always_negative", set()), + os.path.join(outdir, "burst_base_reaction_always_negative.csv"), + ) + _write_one_column_csv( + burst_common_by_category.get("oscillating", set()), + os.path.join(outdir, "burst_base_reaction_oscillating.csv"), + ) + print(f"[INFO] Burst base reaction CSV files saved into: {outdir}") + + print( + "\n[INFO] Base reaction flux preprocessing, visualization, common-category extraction, and burst detection complete." + ) + print(f"[INFO] All files saved to directory: {outdir}") + + # Return same things as original function signature (plus printing above) + return ( + comprehensive_df, + active_base_flux_df, + oscillating_df, + always_zero_df, + base_reaction_details, + base_reaction_mapping, + metadata, + ) + + except Exception as e: + print(f"[ERROR] Analysis failed: {str(e)}") + import traceback + + traceback.print_exc() + return None diff --git a/ecoli/analysis/multigeneration/protein_count.py b/ecoli/analysis/multigeneration/protein_count.py new file mode 100644 index 000000000..fdb6457f9 --- /dev/null +++ b/ecoli/analysis/multigeneration/protein_count.py @@ -0,0 +1,258 @@ +""" +Visualize specific protein counts over time across generations + +You can specify the protein to visualize using the 'protein_id' parameter in params: + "protein_count": { + "protein_id": ["Name1", "Name2", ...], + } +""" + +import altair as alt +import os +from typing import Any +import numpy as np +import pickle + +import polars as pl +from duckdb import DuckDBPyConnection + +from ecoli.library.parquet_emitter import open_arbitrary_sim_data, read_stacked_columns + + +def plot( + params: dict[str, Any], + conn: DuckDBPyConnection, + history_sql: str, + config_sql: str, + success_sql: str, + sim_data_paths: dict[str, dict[int, str]], + validation_data_paths: list[str], + outdir: str, + variant_metadata: dict[str, dict[int, Any]], + variant_names: dict[str, str], +): + """Visualize protein counts over time for specified protein ID across generations.""" + + # Load sim_data + with open_arbitrary_sim_data(sim_data_paths) as f: + sim_data = pickle.load(f) + + # Get protein ID from parameters + protein_id = params.get("protein_id") + + if not protein_id: + print("[ERROR] 'protein_id' parameter is required") + return None + + # Get monomer IDs from simulation data + sim_monomer_ids = sim_data.process.translation.monomer_data["id"] + + # Find the index of the target protein + try: + protein_idx = np.where(sim_monomer_ids == protein_id)[0] + if len(protein_idx) == 0: + print(f"[ERROR] Protein ID '{protein_id}' not found in simulation data") + return None + protein_idx = protein_idx[0] + print(f"[INFO] Found protein '{protein_id}' at index {protein_idx}") + except Exception as e: + print(f"[ERROR] Error finding protein ID '{protein_id}': {e}") + return None + + # Step 1: Get protein counts data using the direct column approach + subquery = read_stacked_columns( + history_sql, ["listeners__monomer_counts"], order_results=False + ) + + sql = f""" + WITH unnested_counts AS ( + SELECT unnest(listeners__monomer_counts) AS counts, + generate_subscripts(listeners__monomer_counts, 1) AS idx, + generation, time + FROM ({subquery}) + ) + SELECT time, generation, + counts as protein_count + FROM unnested_counts + WHERE idx = {protein_idx + 1} -- SQL uses 1-based indexing + ORDER BY generation, time + """ + + df = conn.sql(sql).pl() + + # Step 2: Process the data + # Convert time to minutes + df = df.with_columns((pl.col("time") / 60).alias("time_min")) + + # Calculate statistics per generation + generation_stats = df.group_by("generation").agg( + [ + pl.col("protein_count").mean().alias("mean_count"), + pl.col("protein_count").std().alias("std_count"), + pl.col("protein_count").min().alias("min_count"), + pl.col("protein_count").max().alias("max_count"), + pl.col("protein_count").count().alias("n_points"), + ] + ) + + # Convert to pandas for Altair + plot_df_pd = df.to_pandas() + stats_df_pd = generation_stats.to_pandas() + + print(f"Data shape: {plot_df_pd.shape}") + print(f"Generations: {sorted(plot_df_pd['generation'].unique())}") + print( + f"Time range: {plot_df_pd['time_min'].min():.1f} - {plot_df_pd['time_min'].max():.1f} min" + ) + + # ------------------------ # + + # Create line chart for protein counts over time + def create_protein_count_chart(): + base = alt.Chart(plot_df_pd) + + line = ( + base.mark_line(point=True, strokeWidth=2) + .encode( + x=alt.X("time_min:Q", title="Time (min)"), + y=alt.Y("protein_count:Q", title="Protein Count"), + color=alt.Color( + "generation:N", + legend=alt.Legend(title="Generation"), + scale=alt.Scale(scheme="category10"), + ), + tooltip=["time_min:Q", "generation:N", "protein_count:Q", "agent_id:N"], + ) + .properties( + title=f"Protein Count of {protein_id} Over Time Across Generations", + width=800, + height=400, + ) + ) + return line + + # Create line chart showing mean ± std per generation + def create_generation_summary_chart(): + # Calculate time-averaged values per generation + gen_summary = df.group_by("generation").agg( + [ + pl.col("protein_count").mean().alias("mean_count"), + pl.col("protein_count").std().alias("std_count"), + pl.col("time_min").mean().alias("mean_time"), + ] + ) + + gen_summary_pd = gen_summary.to_pandas() + + # Calculate error bars + gen_summary_pd["upper"] = ( + gen_summary_pd["mean_count"] + gen_summary_pd["std_count"] + ) + gen_summary_pd["lower"] = ( + gen_summary_pd["mean_count"] - gen_summary_pd["std_count"] + ) + + base = alt.Chart(gen_summary_pd) + + # Error bars + error_bars = base.mark_errorbar(extent="stdev").encode( + x=alt.X("generation:N", title="Generation"), + y=alt.Y("mean_count:Q", title="Mean Protein Count"), + yError=alt.YError("std_count:Q"), + ) + + # Points + points = base.mark_point(size=100, filled=True).encode( + x=alt.X("generation:N", title="Generation"), + y=alt.Y("mean_count:Q", title="Mean Protein Count"), + tooltip=["generation:N", "mean_count:Q", "std_count:Q"], + ) + + chart = (error_bars + points).properties( + title=f"Mean Protein Count of {protein_id} by Generation", + width=600, + height=300, + ) + + return chart + + # Create histogram of protein counts by generation + def create_protein_count_histogram(): + hist = ( + alt.Chart(plot_df_pd) + .mark_bar(opacity=0.7) + .encode( + x=alt.X( + "protein_count:Q", bin=alt.Bin(maxbins=30), title="Protein Count" + ), + y=alt.Y("count():Q", title="Frequency"), + color=alt.Color("generation:N", scale=alt.Scale(scheme="category10")), + ) + .properties( + title=f"Distribution of {protein_id} Counts by Generation", + width=600, + height=300, + ) + ) + return hist + + # Create box plot by generation + def create_boxplot_by_generation(): + boxplot = ( + alt.Chart(plot_df_pd) + .mark_boxplot(extent="min-max") + .encode( + x=alt.X("generation:N", title="Generation"), + y=alt.Y("protein_count:Q", title="Protein Count"), + color=alt.Color("generation:N", scale=alt.Scale(scheme="category10")), + ) + .properties( + title=f"Protein Count Distribution of {protein_id} by Generation", + width=600, + height=300, + ) + ) + return boxplot + + # ------------------------ # + + # Generate all charts + main_chart = create_protein_count_chart() + generation_summary = create_generation_summary_chart() + histogram = create_protein_count_histogram() + boxplot = create_boxplot_by_generation() + + # Combine charts in a comprehensive layout + combined_chart = alt.vconcat( + # Main time series chart + main_chart, + # Generation summary and boxplot + alt.hconcat(generation_summary, boxplot), + # Histogram + alt.hconcat(histogram), + title=f"Comprehensive Protein Count Analysis: {protein_id}", + ).resolve_scale(color="independent") + + # Save the visualization + out_path = os.path.join(outdir, "protein_count.html") + combined_chart.save(out_path) + print(f"Saved protein count visualization to: {out_path}") + + # Save summary statistics + stats_path = os.path.join(outdir, "protein_count_stats.csv") + stats_df_pd.to_csv(stats_path, index=False) + print(f"Saved summary statistics to: {stats_path}") + + # Print summary statistics + print(f"\nSummary Statistics for {protein_id}:") + print("=" * 50) + for _, row in stats_df_pd.iterrows(): + gen = int(row["generation"]) + print( + f"Generation {gen}: Mean={row['mean_count']:.2f}, " + f"Std={row['std_count']:.2f}, " + f"Range=[{row['min_count']:.0f}, {row['max_count']:.0f}], " + f"N={row['n_points']}" + ) + + return combined_chart diff --git a/ecoli/analysis/multigeneration/replication.py b/ecoli/analysis/multigeneration/replication.py new file mode 100644 index 000000000..9498bf187 --- /dev/null +++ b/ecoli/analysis/multigeneration/replication.py @@ -0,0 +1,248 @@ +""" +The multigeneration analysis method `replication` +1. Record the DNA polymerase position vs time +2. Record # of pairs of replication forks +3. Record the factors of critical initial mass and dry mass +4. Record # of oriC +""" + +import altair as alt +import os +from typing import Any +import pickle + +from duckdb import DuckDBPyConnection +import polars as pl + +from ecoli.library.parquet_emitter import ( + open_arbitrary_sim_data, + read_stacked_columns, +) + +CRITICAL_N = [1, 2, 4, 8] + +# ----------------------------------------- # + + +def plot( + params: dict[str, Any], + conn: DuckDBPyConnection, + history_sql: str, + config_sql: str, + success_sql: str, + sim_data_dict: dict[str, dict[int, str]], + validation_data_paths: list[str], + outdir: str, + variant_metadata: dict[str, dict[int, Any]], + variant_names: dict[str, str], +): + """Create comprehensive replication visualization plots for E. coli simulation data.""" + # Load sim_data to get genome length + with open_arbitrary_sim_data(sim_data_dict) as f: + sim_data = pickle.load(f) + genome_length = len(sim_data.process.replication.genome_sequence) + + # Define data columns with proper listener names and aliases + data_columns = [ + 'time / 3600 AS "Time (hr)"', + "listeners__replication_data__fork_coordinates AS fork_coordinates", + "listeners__replication_data__number_of_oric AS number_of_oric", + "listeners__mass__cell_mass AS cell_mass", + "listeners__mass__dry_mass AS dry_mass", + "listeners__replication_data__critical_initiation_mass AS critical_initiation_mass", + "listeners__replication_data__critical_mass_per_oric AS critical_mass_per_oric", + ] + + # Load data + plot_data = read_stacked_columns(history_sql, data_columns, conn=conn) + + # Convert to DataFrame + df = pl.DataFrame(plot_data) + + # Process fork coordinates and calculate pairs of forks using Polars + if "fork_coordinates" in df.columns: + df = df.with_columns( + pairs_of_forks=pl.col("fork_coordinates") + .list.eval(~pl.element().is_nan()) + .list.sum() + / 2 + ) + + # Calculate critical mass equivalents + if "cell_mass" in df.columns and "critical_initiation_mass" in df.columns: + df = df.with_columns( + critical_mass_equivalents=( + pl.col("cell_mass") / pl.col("critical_initiation_mass") + ) + ) + + # ----------------------------------------- # + # Create visualization functions + def create_fork_positions_plot(): + """Create DNA polymerase positions scatter plot.""" + if "fork_coordinates" not in df.columns: + return None + + # Explode fork coordinates and filter out NaN values + fork_df = ( + df.select(["Time (hr)", "fork_coordinates"]) + .explode("fork_coordinates") + .filter(~pl.col("fork_coordinates").is_nan()) + .rename({"fork_coordinates": "Position"}) + ) + + if fork_df.height == 0: + return None + return ( + alt.Chart(fork_df) + .mark_circle(size=5, opacity=0.7) + .encode( + x=alt.X("Time (hr):Q", title="Time (hr)"), + y=alt.Y( + "Position:Q", + scale=alt.Scale(domain=[-genome_length / 2, genome_length / 2]), + axis=alt.Axis( + values=[-genome_length / 2, 0, genome_length / 2], + labelExpr="datum.value == 0 ? 'oriC' : (datum.value < 0 ? '-terC' : '+terC')", + ), + title="DNA polymerase position (nt)", + ), + ) + .properties(title="DNA Polymerase Positions", width=600, height=120) + ) + + def create_pairs_of_forks_plot(): + """Create pairs of replication forks line plot.""" + if "pairs_of_forks" not in df.columns: + return None + + return ( + alt.Chart(df) + .mark_line(strokeWidth=2) + .encode( + x=alt.X("Time (hr):Q", title="Time (hr)"), + y=alt.Y( + "pairs_of_forks:Q", + scale=alt.Scale(domain=[0, 6]), + title="Pairs of forks", + ), + ) + .properties(title="Pairs of Replication Forks", width=600, height=100) + ) + + def create_critical_mass_plot(): + """Create critical mass equivalents plot with reference lines.""" + if "critical_mass_equivalents" not in df.columns: + return None + + # Main line plot + base_plot = ( + alt.Chart(df) + .mark_line(strokeWidth=2) + .encode( + x=alt.X("Time (hr):Q", title="Time (hr)"), + y=alt.Y( + "critical_mass_equivalents:Q", + title="Factors of critical initiation mass", + ), + ) + ) + + # Reference lines for critical N values + reference_data = pl.DataFrame( + {"y": CRITICAL_N, "label": [f"N={n}" for n in CRITICAL_N]} + ) + + reference_lines = ( + alt.Chart(reference_data) + .mark_rule(strokeDash=[5, 5], color="gray", opacity=0.7) + .encode(y="y:Q") + ) + + # Text labels for reference lines + reference_labels = ( + alt.Chart(reference_data) + .mark_text(align="left", dx=5, fontSize=10, color="gray") + .encode(y="y:Q", text="label:N") + .transform_calculate(x="0") + .encode(x=alt.X("x:Q")) + ) + + return (base_plot + reference_lines + reference_labels).properties( + title="Factors of Critical Initiation Mass", width=600, height=100 + ) + + def create_mass_plot(column_name: str, title: str, y_title: str): + """Create a generic mass plot.""" + if column_name not in df.columns: + return None + + return ( + alt.Chart(df) + .mark_line(strokeWidth=2) + .encode( + x=alt.X("Time (hr):Q", title="Time (hr)"), + y=alt.Y(f"{column_name}:Q", title=y_title), + ) + .properties(title=title, width=600, height=100) + ) + + # ----------------------------------------- # + # Generate all plots + plots = [] + + # 1. Fork positions + fork_plot = create_fork_positions_plot() + if fork_plot: + plots.append(fork_plot) + + # 2. Pairs of forks + pairs_plot = create_pairs_of_forks_plot() + if pairs_plot: + plots.append(pairs_plot) + + # 3. Critical mass equivalents + critical_plot = create_critical_mass_plot() + if critical_plot: + plots.append(critical_plot) + + # 4. Dry mass + dry_mass_plot = create_mass_plot("dry_mass", "Dry Mass", "Dry mass (fg)") + if dry_mass_plot: + plots.append(dry_mass_plot) + + # 5. Number of oriC + oric_plot = create_mass_plot("number_of_oric", "Number of oriC", "Number of oriC") + if oric_plot: + plots.append(oric_plot) + + # 6. Critical mass per oriC + mass_per_oric_plot = create_mass_plot( + "critical_mass_per_oric", "Critical Mass per oriC", "Critical mass per oriC" + ) + if mass_per_oric_plot: + plots.append(mass_per_oric_plot) + + # Combine plots or create fallback + if plots: + combined_plot = alt.vconcat(*plots).resolve_scale(x="shared") + print(f"Created visualization with {len(plots)} subplots") + else: + # Fallback plot if no data available + fallback_data = pl.DataFrame( + {"x": [0], "y": [0], "text": ["No data available for plotting"]} + ) + combined_plot = ( + alt.Chart(fallback_data) + .mark_text(fontSize=20, color="red") + .encode(x=alt.X("x:Q", axis=None), y=alt.Y("y:Q", axis=None), text="text:N") + .properties(width=600, height=400, title="Replication Data Visualization") + ) + print("No plottable data found - created fallback message") + + # Save the plot + output_path = os.path.join(outdir, "replication_report.html") + combined_plot.save(output_path) + print(f"Saved visualization to: {output_path}") + + return combined_plot diff --git a/ecoli/analysis/multigeneration/ribosome_components.py b/ecoli/analysis/multigeneration/ribosome_components.py new file mode 100644 index 000000000..420106d00 --- /dev/null +++ b/ecoli/analysis/multigeneration/ribosome_components.py @@ -0,0 +1,163 @@ +""" +Record the 30S and 50S component count vs time +""" + +import altair as alt +import os +from typing import Any, Dict + +from duckdb import DuckDBPyConnection +import pickle +import polars as pl + +from ecoli.library.parquet_emitter import ( + field_metadata, + open_arbitrary_sim_data, + named_idx, + read_stacked_columns, +) + +# ----------------------------------------- # + + +def plot( + params: Dict[str, Any], + conn: DuckDBPyConnection, + history_sql: str, + config_sql: str, + success_sql: str, + sim_data_dict: Dict[str, Dict[int, str]], + validation_data_paths: list[str], + outdir: str, + variant_metadata: Dict[str, Dict[int, Any]], + variant_names: Dict[str, str], +): + # Load simulation data + with open_arbitrary_sim_data(sim_data_dict) as f: + sim_data = pickle.load(f) + + # Extract molecule IDs for ribosomal subunits + s30_protein_ids = sim_data.molecule_groups.s30_proteins + s30_16s_rRNA_ids = sim_data.molecule_groups.s30_16s_rRNA + s30_full_complex_id = sim_data.molecule_ids.s30_full_complex + s50_protein_ids = sim_data.molecule_groups.s50_proteins + s50_23s_rRNA_ids = sim_data.molecule_groups.s50_23s_rRNA + s50_5s_rRNA_ids = sim_data.molecule_groups.s50_5s_rRNA + s50_full_complex_id = sim_data.molecule_ids.s50_full_complex + + # Retrieve stoichiometry for each protein subunit + complexation = sim_data.process.complexation + s30_info = complexation.get_monomers(s30_full_complex_id) + s50_info = complexation.get_monomers(s50_full_complex_id) + s30_stoich = dict(zip(s30_info["subunitIds"], s30_info["subunitStoich"])) + s50_stoich = dict(zip(s50_info["subunitIds"], s50_info["subunitStoich"])) + + # Map bulk IDs to SQL column indices + bulk_ids = field_metadata(conn, config_sql, "bulk") + bulk_index = {mid: idx for idx, mid in enumerate(bulk_ids)} + + # Determine column indexes in SQL for rRNAs and complexes + s30_16s_idx = [bulk_index[i] for i in s30_16s_rRNA_ids if i in bulk_index] + s50_23s_idx = [bulk_index[i] for i in s50_23s_rRNA_ids if i in bulk_index] + s50_5s_idx = [bulk_index[i] for i in s50_5s_rRNA_ids if i in bulk_index] + s30_complex_idx = bulk_index[s30_full_complex_id] + s50_complex_idx = bulk_index[s50_full_complex_id] + + # Map monomer counts IDs to SQL column indices + mono_ids = field_metadata(conn, config_sql, "listeners__monomer_counts") + mono_index = {mid: idx for idx, mid in enumerate(mono_ids)} + s30_protein_idx = [mono_index[i] for i in s30_protein_ids if i in mono_index] + s50_protein_idx = [mono_index[i] for i in s50_protein_ids if i in mono_index] + + # Build named_idx spec for reading + bulk_cols = [ + named_idx("bulk", s30_16s_rRNA_ids, [s30_16s_idx]), + named_idx("bulk", s50_23s_rRNA_ids, [s50_23s_idx]), + named_idx("bulk", s50_5s_rRNA_ids, [s50_5s_idx]), + named_idx("bulk", [s30_full_complex_id], [[s30_complex_idx]]), + named_idx("bulk", [s50_full_complex_id], [[s50_complex_idx]]), + ] + protein_cols = [ + named_idx("listeners__monomer_counts", [pid], [[idx]]) + for pid, idx in zip( + s30_protein_ids + s50_protein_ids, s30_protein_idx + s50_protein_idx + ) + ] + additional = ["listeners__unique_molecule_counts__active_ribosome", "time"] + cols = bulk_cols + protein_cols + additional + + # Read time-series data + data = read_stacked_columns(history_sql, cols, conn=conn) + df = pl.DataFrame(data).with_columns(Time_min=pl.col("time") / 60) + + # Sum rRNA counts horizontally + s30_16s = pl.sum_horizontal([pl.col(i) for i in s30_16s_rRNA_ids]) + s50_23s = pl.sum_horizontal([pl.col(i) for i in s50_23s_rRNA_ids]) + s50_5s = pl.sum_horizontal([pl.col(i) for i in s50_5s_rRNA_ids]) + + # Extract complex and active ribosome counts + s30_complex = pl.col(s30_full_complex_id) + s50_complex = pl.col(s50_full_complex_id) + active_ribo = pl.col("listeners__unique_molecule_counts__active_ribosome") + + # Adjust protein counts by stoichiometry + for pid in s30_protein_ids: + df = df.with_columns(**{f"adj_s30_{pid}": pl.col(pid) / s30_stoich[pid]}) + for pid in s50_protein_ids: + df = df.with_columns(**{f"adj_s50_{pid}": pl.col(pid) / s50_stoich[pid]}) + + # Determine limiting protein across subunits + s30_lim = pl.min_horizontal([pl.col(f"adj_s30_{pid}") for pid in s30_protein_ids]) + s50_lim = pl.min_horizontal([pl.col(f"adj_s50_{pid}") for pid in s50_protein_ids]) + + # Calculate total rRNA including complexes and active ribosomes + df = df.with_columns( + s30_16s_total=s30_16s + s30_complex + active_ribo, + s50_23s_total=s50_23s + s50_complex + active_ribo, + s50_5s_total=s50_5s + s50_complex + active_ribo, + s30_limiting=s30_lim, + s50_limiting=s50_lim, + s30_total=s30_complex + active_ribo, + s50_total=s50_complex + active_ribo, + ) + + # ----------------------------------------- # + + plot_cols_30 = ["s30_limiting", "s30_16s_total", "s30_total"] + plot_cols_50 = ["s50_limiting", "s50_23s_total", "s50_5s_total", "s50_total"] + + melt_30 = df.select(["Time_min"] + plot_cols_30).melt( + id_vars="Time_min", variable_name="component", value_name="count" + ) + melt_50 = df.select(["Time_min"] + plot_cols_50).melt( + id_vars="Time_min", variable_name="component", value_name="count" + ) + + chart_30 = ( + alt.Chart(melt_30) + .mark_line() + .encode( + x="Time_min", + y="count", + color=alt.Color("component", title="30S Components"), + ) + .properties(title="30S Component Counts", width=600) + ) + + chart_50 = ( + alt.Chart(melt_50) + .mark_line() + .encode( + x="Time_min", + y="count", + color=alt.Color("component", title="50S Components"), + ) + .properties(title="50S Component Counts", width=600) + ) + + combined = ( + alt.vconcat(chart_30, chart_50) + .resolve_scale(color="independent") + .resolve_legend(color="independent") + ) + combined.save(os.path.join(outdir, "ribosome_components.html")) diff --git a/ecoli/analysis/multigeneration/ribosome_crowding.py b/ecoli/analysis/multigeneration/ribosome_crowding.py new file mode 100644 index 000000000..97b54c2a8 --- /dev/null +++ b/ecoli/analysis/multigeneration/ribosome_crowding.py @@ -0,0 +1,246 @@ +""" +Record the translation probability comparison on Gene EG10184 +""" + +import altair as alt +import os +from typing import Any +import pickle +import polars as pl +from duckdb import DuckDBPyConnection + +from ecoli.library.parquet_emitter import ( + field_metadata, + open_arbitrary_sim_data, + named_idx, +) + +MAX_NUMBER_OF_MONOMERS_TO_PLOT = 300 + +# ----------------------------------------- # + + +def plot( + params: dict[str, Any], + conn: DuckDBPyConnection, + history_sql: str, + config_sql: str, + success_sql: str, + sim_data_dict: dict[str, dict[int, str]], + validation_data_paths: list[str], + outdir: str, + variant_metadata: dict[str, dict[int, Any]], + variant_names: dict[str, str], +): + """ + Comparison of target vs actual translation probabilities for overcrowded mRNAs. + """ + + # Load sim_data for monomer mappings + with open_arbitrary_sim_data(sim_data_dict) as f: + sim_data = pickle.load(f) + + # Get monomer and gene mappings + mRNA_data = sim_data.process.transcription.cistron_data.struct_array + monomer_data = sim_data.process.translation.monomer_data.struct_array + + monomer_to_gene = {} + for mono_id, cistron_id in zip(monomer_data["id"], monomer_data["cistron_id"]): + gene_id = next( + ( + g + for c, g in zip(mRNA_data["id"], mRNA_data["gene_id"]) + if c == cistron_id + ), + "Unknown", + ) + monomer_to_gene[mono_id] = gene_id + + # Get field metadata + try: + field_names = field_metadata( + conn, + config_sql, + "listeners__ribosome_data__target_prob_translation_per_transcript", + ) + except Exception as e: + print(f"[ERROR] Error getting field metadata: {e}") + return + + # First pass: Find overcrowded monomer indices + # If gene X's target > actual at any timepoint t, it'll be marked as overcrowded. + overcrowded_query = f""" + WITH unnested AS ( + SELECT + unnest(listeners__ribosome_data__actual_prob_translation_per_transcript) as actual, + unnest(listeners__ribosome_data__target_prob_translation_per_transcript) as target, + generate_subscripts(listeners__ribosome_data__target_prob_translation_per_transcript, 1) as idx + FROM ({history_sql}) + ) + SELECT DISTINCT idx + FROM unnested + WHERE target > actual + ORDER BY idx + LIMIT {MAX_NUMBER_OF_MONOMERS_TO_PLOT} + """ + + overcrowded_indices = [ + row[0] - 1 for row in conn.execute(overcrowded_query).fetchall() + ] # Convert to 0-based + + if not overcrowded_indices: + print("[INFO] No overcrowded monomers found.") + return + + n_overcrowded_monomers = len(overcrowded_indices) + n_overcrowded_monomers_to_plot = min( + n_overcrowded_monomers, MAX_NUMBER_OF_MONOMERS_TO_PLOT + ) + + print(f"[INFO] Found {n_overcrowded_monomers} overcrowded monomers") + + # Second pass: Get data for overcrowded monomers only + actual_columns = [] + target_columns = [] + + for i, idx in enumerate(overcrowded_indices): + if i >= n_overcrowded_monomers_to_plot: + break + if idx < len(field_names): + monomer_id = field_names[idx] + gene_id = monomer_to_gene.get(monomer_id, "Unknown") + actual_columns.append(f"actual__{gene_id}") + target_columns.append(f"target__{gene_id}") + + actual_expr = named_idx( + "listeners__ribosome_data__actual_prob_translation_per_transcript", + actual_columns, + [overcrowded_indices[: len(actual_columns)]], + ) + + target_expr = named_idx( + "listeners__ribosome_data__target_prob_translation_per_transcript", + target_columns, + [overcrowded_indices[: len(target_columns)]], + ) + + data_query = f"SELECT {actual_expr}, {target_expr}, time FROM ({history_sql})" + df = conn.execute(data_query).fetchdf() + + # ----------------------------------------- # + # Prepare plot data following original format + pl_df = pl.DataFrame(df) + + # Get all probability columns (both actual and target) + prob_columns = actual_columns + target_columns + + # Unpivot the data + plot_df = ( + pl_df.unpivot( + index=["time"], + on=prob_columns, + variable_name="variable", + value_name="Translation_Probability", + ) + .with_columns( + [ + # Split variable name into probability type and gene ID + pl.col("variable") + .str.split_exact("__", 1) + .struct.rename_fields(["Probability_Type", "Gene_ID"]), + (pl.col("time") / 60).alias("Time_min"), + ] + ) + .unnest("variable") + ) + + # Get unique gene IDs in the order they appear in the data + unique_genes = plot_df["Gene_ID"].unique().to_list() + + # ----------------------------------------- # + # Create individual plots for each overcrowded gene + charts = [] + for i, gene_id in enumerate(unique_genes[:n_overcrowded_monomers_to_plot]): + gene_data = plot_df.filter(pl.col("Gene_ID") == gene_id) + + if gene_data.height == 0: + continue + + gene_id = gene_data["Gene_ID"][0] + + # Create chart with simplified encoding and proper tooltip + chart = ( + alt.Chart(gene_data) + .mark_line(point=False, strokeWidth=2) + .encode( + x=alt.X("Time_min:Q", title="Time (min)", scale=alt.Scale(nice=True)), + y=alt.Y( + "Translation_Probability:Q", + title=f"{gene_id} translation probability", + scale=alt.Scale(nice=True), + ), + color=alt.Color( + "Probability_Type:N", + scale=alt.Scale( + # actually, the blue target line will cover the orange actual line if they are the same + domain=["target", "actual"], + range=["#1f77b4", "#ff7f0e"], + ), + legend=alt.Legend(title="Type") if i == 0 else None, + ), + tooltip=[ + alt.Tooltip("Time_min:Q", title="Time (min)", format=".2f"), + alt.Tooltip( + "Translation_Probability:Q", title="Probability", format=".4f" + ), + alt.Tooltip("Probability_Type:N", title="Type"), + alt.Tooltip("Gene_ID:N", title="Gene"), + ], + ) + .properties( + width=600, + height=150, + title=alt.TitleParams( + text=[ + f"Gene {gene_id} - Translation Probability Comparison", + f"Total overcrowded proteins: {n_overcrowded_monomers}" + + ( + f" (showing first {MAX_NUMBER_OF_MONOMERS_TO_PLOT})" + if n_overcrowded_monomers > MAX_NUMBER_OF_MONOMERS_TO_PLOT + else "" + ) + if i == 0 + else "", + ], + fontSize=12, + anchor="start", + ), + ) + ) + + charts.append(chart) + + if charts: + combined_chart = ( + alt.vconcat(*charts) + .resolve_scale(color="independent") + .add_params(alt.selection_interval(bind="scales")) + ) + + alt.data_transformers.enable("json") + + output_path = os.path.join(outdir, "ribosome_crowding.html") + combined_chart.save(output_path) + + print( + f"[INFO] Generated ribosome crowding plot for {len(charts)} overcrowded proteins" + ) + print(f"[INFO] Plot saved to: {output_path}") + + # Also save as JSON for debugging if needed + json_path = os.path.join(outdir, "ribosome_crowding.json") + combined_chart.save(json_path) + print(f"[INFO] Chart specification saved to: {json_path}") + + else: + print("[INFO] No charts created - no data to plot") diff --git a/ecoli/analysis/multigeneration/ribosome_production.py b/ecoli/analysis/multigeneration/ribosome_production.py new file mode 100644 index 000000000..93e2b57df --- /dev/null +++ b/ecoli/analysis/multigeneration/ribosome_production.py @@ -0,0 +1,371 @@ +import os +from typing import Any +import altair as alt +import pickle +import polars as pl +import numpy as np +from duckdb import DuckDBPyConnection +import pandas as pd + +from ecoli.library.parquet_emitter import open_arbitrary_sim_data, named_idx +from ecoli.library.schema import bulk_name_to_idx + + +# ----------------------------------------- # + + +def calc_rna_doubling_time( + produced_col: str, count_col: str, borderline: float +) -> pl.Expr: + """ + Calculate rRNA doubling time with sanitation. + """ + production_rate = pl.col(produced_col) / pl.col("time_step_sec") + growth_rate = production_rate / pl.col(count_col) + dt_min = float(np.log(2)) / growth_rate / 60 + valid = ( + (pl.col(produced_col) >= 0) + & (pl.col(count_col) > 0) + & (growth_rate > 0) + & dt_min.is_finite() + & (dt_min > 0) + & (dt_min < 2 * borderline) + ) + return pl.when(valid).then(dt_min).otherwise(None) + + +def plot( + params: dict[str, Any], + conn: DuckDBPyConnection, + history_sql: str, + config_sql: str, + success_sql: str, + sim_data_dict: dict[str, dict[int, str]], + validation_data_paths: list[str], + outdir: str, + variant_metadata: dict[str, dict[int, Any]], + variant_names: dict[str, str], +): + """Visualize ribosome production metrics for E. coli simulation.""" + with open_arbitrary_sim_data(sim_data_dict) as f: + sim_data = pickle.load(f) + + sim_doubling_time = sim_data.doubling_time.asNumber() + + # define rRNA groups and bulk IDs + s30_16s = list(sim_data.molecule_groups.s30_16s_rRNA) + [ + sim_data.molecule_ids.s30_full_complex + ] + s50_23s = list(sim_data.molecule_groups.s50_23s_rRNA) + [ + sim_data.molecule_ids.s50_full_complex + ] + s50_5s = list(sim_data.molecule_groups.s50_5s_rRNA) + [ + sim_data.molecule_ids.s50_full_complex + ] + bulk_ids = sim_data.internal_state.bulk_molecules.bulk_data["id"].tolist() + + # precompute indices as Python ints + idx_16s = [int(i) for i in np.atleast_1d(bulk_name_to_idx(s30_16s, bulk_ids))] + idx_23s = [int(i) for i in np.atleast_1d(bulk_name_to_idx(s50_23s, bulk_ids))] + idx_5s = [int(i) for i in np.atleast_1d(bulk_name_to_idx(s50_5s, bulk_ids))] + + required_columns = [ + "time", + "variant", + "generation", + "agent_id", + "listeners__mass__instantaneous_growth_rate", + "listeners__mass__dry_mass", + "listeners__ribosome_data__rRNA16S_initiated", + "listeners__ribosome_data__rRNA23S_initiated", + "listeners__ribosome_data__rRNA5S_initiated", + "listeners__ribosome_data__rRNA16S_init_prob", + "listeners__ribosome_data__rRNA23S_init_prob", + "listeners__ribosome_data__rRNA5S_init_prob", + "listeners__ribosome_data__effective_elongation_rate", + "listeners__unique_molecule_counts__active_ribosome", + ] + + # load data + # Create the bulk index expressions + bulk_16s_expr = named_idx("bulk", [f"bulk_{i}" for i in idx_16s], [idx_16s]) + bulk_23s_expr = named_idx("bulk", [f"bulk_{i}" for i in idx_23s], [idx_23s]) + bulk_5s_expr = named_idx("bulk", [f"bulk_{i}" for i in idx_5s], [idx_5s]) + + # Combine all columns and expressions + all_columns = ", ".join(required_columns) + bulk_expressions = ", ".join([bulk_16s_expr, bulk_23s_expr, bulk_5s_expr]) + + # Build the SQL query + sql = f""" + SELECT {all_columns}, {bulk_expressions} + FROM ({history_sql}) + WHERE agent_id = 0 + ORDER BY generation, time + """ + + df = conn.sql(sql).pl() + + # time + df = df.with_columns((pl.col("time") / 60).alias("time_min")) + df = df.with_columns( + pl.col("time") + .diff() + .over(["variant", "generation", "agent_id"]) + .alias("time_step_sec") + ) + df = df.with_columns( + time_step_sec=pl.when(pl.col("time_step_sec").is_null()) + .then(pl.col("time")) + .otherwise(pl.col("time_step_sec")) + ) + + # cell doubling time + if "listeners__mass__instantaneous_growth_rate" in df.columns: + val = ( + float(np.log(2)) / pl.col("listeners__mass__instantaneous_growth_rate") / 60 + ) + df = df.with_columns( + pl.when(val.is_between(0, 2 * sim_doubling_time, closed="both")) + .then(val) + .otherwise(None) + .alias("cell_doubling_time_min") + ) + + df = df.with_columns( + [ + pl.sum_horizontal([pl.col(f"bulk_{i}") for i in idx_16s]).alias( + "bulk_16s_count" + ), + pl.sum_horizontal([pl.col(f"bulk_{i}") for i in idx_23s]).alias( + "bulk_23s_count" + ), + pl.sum_horizontal([pl.col(f"bulk_{i}") for i in idx_5s]).alias( + "bulk_5s_count" + ), + pl.col("listeners__unique_molecule_counts__active_ribosome") + .fill_null(0) + .alias("ribosome_count"), + ] + ) + + # total rRNA + df = df.with_columns( + [ + (pl.col("bulk_16s_count") + pl.col("ribosome_count")).alias("rrn16s_count"), + (pl.col("bulk_23s_count") + pl.col("ribosome_count")).alias("rrn23s_count"), + (pl.col("bulk_5s_count") + pl.col("ribosome_count")).alias("rrn5s_count"), + ] + ) + + # rRNA doubling times + if "listeners__ribosome_data__rRNA16S_initiated" in df.columns: + df = df.with_columns( + rrn16S_doubling_time_min=calc_rna_doubling_time( + "listeners__ribosome_data__rRNA16S_initiated", + "rrn16s_count", + sim_doubling_time, + ) + ) + if "listeners__ribosome_data__rRNA23S_initiated" in df.columns: + df = df.with_columns( + rrn23S_doubling_time_min=calc_rna_doubling_time( + "listeners__ribosome_data__rRNA23S_initiated", + "rrn23s_count", + sim_doubling_time, + ) + ) + if "listeners__ribosome_data__rRNA5S_initiated" in df.columns: + df = df.with_columns( + rrn5S_doubling_time_min=calc_rna_doubling_time( + "listeners__ribosome_data__rRNA5S_initiated", + "rrn5s_count", + sim_doubling_time, + ) + ) + + # reference probabilities + cond = sim_data.condition + trans = sim_data.process.transcription + synth_probs = trans.cistron_tu_mapping_matrix.dot(trans.rna_synth_prob[cond]) + + def fit_prob(group_ids): + cistrons = [rid[:-3] for rid in group_ids] + idxs = np.where(np.isin(trans.cistron_data["id"], cistrons))[0] + return synth_probs[idxs].sum() if idxs.size else 0.0 + + ref_probs = { + "16S": fit_prob(sim_data.molecule_groups.s30_16s_rRNA), + "23S": fit_prob(sim_data.molecule_groups.s50_23s_rRNA), + "5S": fit_prob(sim_data.molecule_groups.s50_5s_rRNA), + } + + # ----------------------------------------- # + # prepare for plotting + plot_cols = ["time_min", "variant", "generation"] + + for c in [ + "listeners__mass__dry_mass", + "cell_doubling_time_min", + "rrn16S_doubling_time_min", + "rrn23S_doubling_time_min", + "rrn5S_doubling_time_min", + "rrn16S_init_prob", + "rrn23S_init_prob", + "rrn5S_init_prob", + "listeners__ribosome_data__effective_elongation_rate", + ]: + if c in df.columns: + plot_cols.append(c) + + plot_df = df.select(plot_cols) + + init_dm = ( + plot_df.filter(pl.col("time_min") == 0) + .select(["variant", "listeners__mass__dry_mass"]) + .rename({"listeners__mass__dry_mass": "initial_dry_mass"}) + ) + plot_df = plot_df.join(init_dm, on=["variant"], how="left") + plot_df = plot_df.with_columns( + (pl.col("listeners__mass__dry_mass") / pl.col("initial_dry_mass")).alias( + "dry_mass_normalized" + ) + ) + + # generate Altair charts + def create_line_chart(y, title, y_title, ref=None): + base = alt.Chart(plot_df) + line = ( + base.mark_line() + .encode( + x=alt.X("time_min:Q", title="Time (min)"), + y=alt.Y(f"{y}:Q", title=y_title), + color=alt.Color( + "generation:N", + legend=alt.Legend(title="Simulated Multigeneration Data"), + ), + ) + .properties(title=title, width=600, height=120) + ) + if ref is not None: + rule = ( + alt.Chart(pd.DataFrame({"y": [ref]})) + .mark_rule(color="red", strokeDash=[5, 5]) + .encode(y="y:Q") + ) + return line + rule + return line + + def create_histogram( + col: str, title: str, bins: int = 30, probability: bool = False + ) -> alt.Chart: + if probability: + density = ( + alt.Chart(plot_df) + .transform_density(col, as_=[col, "density"], counts=False, steps=bins) + .mark_area(opacity=0.6) + .encode( + x=alt.X(f"{col}:Q", title=f"bin={bins}"), + y=alt.Y("density:Q", title="Density"), + ) + .properties(width=200, height=120, title=title) + ) + return density + else: + hist = ( + alt.Chart(plot_df) + .mark_bar(opacity=0.6) + .encode( + x=alt.X(f"{col}:Q", bin=alt.Bin(maxbins=bins), title=f"bin={bins}"), + y=alt.Y("count():Q", title="Count"), + color=alt.value("steelblue"), + ) + .properties(width=200, height=120, title=title) + ) + return hist + + plots = [] + # Dry mass + if "dry_mass_normalized" in plot_df.columns: + line = create_line_chart( + "dry_mass_normalized", + "Normalized Dry Mass Over Time", + "Dry mass (relative to t=0)", + ) + hist = create_histogram( + "dry_mass_normalized", "Normalized Dry Mass Distribution", probability=True + ) + plots.append(alt.hconcat(line, hist)) + # Cell Doubling Time + if "cell_doubling_time_min" in plot_df.columns: + line = create_line_chart( + "cell_doubling_time_min", + "Cell Doubling Time", + "Doubling Time (min)", + sim_doubling_time, + ) + hist = create_histogram( + "cell_doubling_time_min", + "Cell Doubling Time (min) Distribution", + probability=True, + ) + plots.append(alt.hconcat(line, hist)) + # rRNA Doubl;ing Time + for suffix in ["16S", "23S", "5S"]: + col = f"rrn{suffix}_doubling_time_min" + if col in plot_df.columns: + line = create_line_chart( + col, + f"{suffix} rRNA Doubling Time", + "Doubling Time (min)", + sim_doubling_time, + ) + hist = create_histogram( + col, f"{suffix} rRNA Doubling Time Distribution", probability=True + ) + plots.append(alt.hconcat(line, hist)) + # rRNA Initiation Probability + for suffix, ref in ref_probs.items(): + col = f"rrn{suffix}_init_prob" + if col in plot_df.columns: + line = create_line_chart( + col, f"{suffix} rRNA Initiation Probability", "Probability", ref + ) + hist = create_histogram( + col, + f"{suffix} rRNA Initiation Probability Distribution", + probability=True, + ) + plots.append(alt.hconcat(line, hist)) + # Ribosome Elongation Rate + if "listeners__ribosome_data__effective_elongation_rate" in plot_df.columns: + line = create_line_chart( + "listeners__ribosome_data__effective_elongation_rate", + "Ribosome Elongation Rate", + "Amino acids/s", + ) + hist = create_histogram( + "listeners__ribosome_data__effective_elongation_rate", + "Ribosome Elongation Rate Distribution", + probability=True, + ) + plots.append(alt.hconcat(line, hist)) + + if not plots: + fallback = pl.DataFrame({"message": ["No data available"], "x": [0], "y": [0]}) + plots.append( + alt.Chart(fallback) + .mark_text(size=20, color="red") + .encode(x="x:Q", y="y:Q", text="message:N") + .properties(width=600, height=400, title="No Data") + ) + + combined = ( + alt.vconcat(*plots) + .resolve_scale(x="shared", y="independent") + .properties(title="Ribosome Production Metrics") + ) + out_path = os.path.join(outdir, "ribosome_production_report.html") + combined.save(out_path) + print(f"Saved visualization to: {out_path}") + return combined diff --git a/ecoli/analysis/multigeneration/ribosome_usage.py b/ecoli/analysis/multigeneration/ribosome_usage.py new file mode 100644 index 000000000..04bcc8a15 --- /dev/null +++ b/ecoli/analysis/multigeneration/ribosome_usage.py @@ -0,0 +1,442 @@ +""" +Record several things: +1. cell volume over time +2. total / active ribosome count and concentration +3. active ribosome molar / mass fraction +4. Ribosome activation / deactivation count +5. # of AA. be translated +6. the effective ribosome elongation rate +""" + +import altair as alt +import os +from typing import Any +import pickle + +import polars as pl +from duckdb import DuckDBPyConnection +import pandas as pd +import numpy as np + +from ecoli.library.parquet_emitter import open_arbitrary_sim_data, named_idx +from ecoli.library.schema import bulk_name_to_idx + +# ----------------------------------------- # + + +def plot( + params: dict[str, Any], + conn: DuckDBPyConnection, + history_sql: str, + config_sql: str, + success_sql: str, + sim_data_dict: dict[str, dict[int, str]], + validation_data_paths: list[str], + outdir: str, + variant_metadata: dict[str, dict[int, Any]], + variant_names: dict[str, str], +): + """Visualize ribosome usage statistics for E. coli simulation.""" + # Load sim_data + with open_arbitrary_sim_data(sim_data_dict) as f: + sim_data = pickle.load(f) + + # Get molecular IDs for ribosome subunits + complex_ids_30s = [sim_data.molecule_ids.s30_full_complex] + complex_ids_50s = [sim_data.molecule_ids.s50_full_complex] + bulk_ids = sim_data.internal_state.bulk_molecules.bulk_data["id"].tolist() + + # precompute indices as Python ints (following ribosome_production.py pattern) + idx_30s = [ + int(i) for i in np.atleast_1d(bulk_name_to_idx(complex_ids_30s, bulk_ids)) + ] + idx_50s = [ + int(i) for i in np.atleast_1d(bulk_name_to_idx(complex_ids_50s, bulk_ids)) + ] + + # Get molecular weights + n_avogadro = sim_data.constants.n_avogadro + mw_30s = sim_data.getter.get_masses(complex_ids_30s) + mw_50s = sim_data.getter.get_masses(complex_ids_50s) + mw_70s = mw_30s + mw_50s + + required_columns = [ + "time", + "variant", + "generation", + "agent_id", + "experiment_id", + "lineage_seed", + "listeners__mass__instantaneous_growth_rate", + "listeners__mass__cell_mass", + "listeners__mass__volume", + "listeners__ribosome_data__did_initialize", + "listeners__ribosome_data__actual_elongations", + "listeners__ribosome_data__did_terminate", + "listeners__ribosome_data__effective_elongation_rate", + "listeners__unique_molecule_counts__active_ribosome", + ] + + # Create the bulk index expressions + expr_30s = named_idx("bulk", [f"bulk_30s_{i}" for i in idx_30s], [idx_30s]) + expr_50s = named_idx("bulk", [f"bulk_50s_{i}" for i in idx_50s], [idx_50s]) + + # load data + sql = f""" + SELECT + {", ".join(required_columns)}, + {expr_30s}, + {expr_50s} + FROM ({history_sql}) + WHERE agent_id = 0 + ORDER BY generation, time + """ + + df = conn.sql(sql).pl() + + # Convert time + if "time" in df.columns: + df = df.with_columns((pl.col("time") / 60).alias("time_min")) + df = df.with_columns([(pl.col("time") + 1).alias("time_step_sec")]) + + # Calculate ribosome subunit counts + cols_30s = [c for c in df.columns if c.startswith("bulk_30s_")] + cols_50s = [c for c in df.columns if c.startswith("bulk_50s_")] + df = df.with_columns( + [ + # compute bulk ribosome subunit counts + pl.sum_horizontal(cols_30s).alias("counts_30s"), + pl.sum_horizontal(cols_50s).alias("counts_50s"), + # compute unique ribosomes + pl.col("listeners__unique_molecule_counts__active_ribosome") + .fill_null(0) + .alias("active_ribosome_counts"), + ] + ) + + # Calculate total ribosome counts and fractions + df = df.with_columns( + [ + ( + pl.col("active_ribosome_counts") + + pl.min_horizontal(pl.col("counts_30s"), pl.col("counts_50s")) + ).alias("total_ribosome_counts"), + ( + pl.col("active_ribosome_counts").cast(pl.Float64) + / ( + pl.col("active_ribosome_counts") + + pl.min_horizontal(pl.col("counts_30s"), pl.col("counts_50s")) + ) + ).alias("molar_fraction_active"), + ] + ) + + if "listeners__mass__cell_mass" in df.columns: + cell_density = sim_data.constants.cell_density.asNumber() + df = df.with_columns( + (1e-15 * pl.col("listeners__mass__cell_mass") / cell_density).alias( + "cell_volume" + ) + ) + + # Calculate concentrations + df = df.with_columns( + [ + ( + pl.col("total_ribosome_counts") + / n_avogadro.asNumber() + / pl.col("cell_volume") + ).alias("total_ribosome_concentration_mM"), + ( + pl.col("active_ribosome_counts") + / n_avogadro.asNumber() + / pl.col("cell_volume") + ).alias("active_ribosome_concentration_mM"), + ] + ) + + # Calculate masses + mw_30s_value = mw_30s.asNumber() if hasattr(mw_30s, "asNumber") else float(mw_30s) + mw_50s_value = mw_50s.asNumber() if hasattr(mw_50s, "asNumber") else float(mw_50s) + mw_70s_value = mw_70s.asNumber() if hasattr(mw_70s, "asNumber") else float(mw_70s) + + df = df.with_columns( + [ + (pl.col("counts_30s") / n_avogadro.asNumber() * mw_30s_value).alias( + "mass_30s" + ), + (pl.col("counts_50s") / n_avogadro.asNumber() * mw_50s_value).alias( + "mass_50s" + ), + ( + pl.col("active_ribosome_counts") / n_avogadro.asNumber() * mw_70s_value + ).alias("active_ribosome_mass"), + ] + ) + + df = df.with_columns( + [ + ( + pl.col("active_ribosome_mass") + pl.col("mass_30s") + pl.col("mass_50s") + ).alias("total_ribosome_mass"), + ( + pl.col("active_ribosome_mass") + / ( + pl.col("active_ribosome_mass") + + pl.col("mass_30s") + + pl.col("mass_50s") + ) + ).alias("mass_fraction_active"), + ] + ) + + # Calculate rates per time and volume + if "time_step_sec" in df.columns and "cell_volume" in df.columns: + df = df.with_columns( + [ + ( + pl.col("listeners__ribosome_data__did_initialize") + / (pl.col("cell_volume") / 1e-15) + ).alias("activations_per_volume"), + ( + pl.col("listeners__ribosome_data__did_terminate") + / (pl.col("cell_volume") / 1e-15) + ).alias("deactivations_per_volume"), + ] + ) + + # Select columns for plotting + plot_columns = ["time_min", "variant", "generation"] + + # Add other columns that exist + for col in [ + "time_step_sec", + "cell_volume", + "total_ribosome_counts", + "total_ribosome_concentration_mM", + "active_ribosome_counts", + "active_ribosome_concentration_mM", + "molar_fraction_active", + "mass_fraction_active", + "listeners__ribosome_data__did_initialize", + "listeners__ribosome_data__did_terminate", + "activations_per_volume", + "deactivations_per_volume", + "listeners__ribosome_data__actual_elongations", + "listeners__ribosome_data__effective_elongation_rate", + ]: + if col in df.columns: + plot_columns.append(col) + + plot_df = df.select(plot_columns) + + # ----------------------------------------- # + + def create_line_chart(y_field, title, y_title, skip_first_point=False): + """Create line chart with optional skipping of first data point.""" + data = plot_df.to_pandas() + if skip_first_point: + # Group by variant and generation, skip first point of each group + filtered_data = [] + for (variant, generation), group in data.groupby(["variant", "generation"]): + if len(group) > 1: + filtered_data.append(group.iloc[1:]) + else: + filtered_data.append(group) + data = ( + pd.concat(filtered_data, ignore_index=True) if filtered_data else data + ) + + chart = ( + alt.Chart(data) + .mark_line() + .encode( + x=alt.X("time_min:Q", title="Time (min)"), + y=alt.Y(f"{y_field}:Q", title=y_title), + color=alt.Color("generation:N", legend=alt.Legend(title="Generation")), + ) + .properties(title=title, width=600, height=120) + ) + + return chart + + # ----------------------------------------- # + plots = [] + + # Create all 14 plots following the original order + if "time_step_sec" in plot_df.columns: + plots.append( + create_line_chart( + "time_step_sec", "Length of Time Step", "Length of time step (s)" + ) + ) + + if "cell_volume" in plot_df.columns: + plots.append(create_line_chart("cell_volume", "Cell Volume", "Cell volume (L)")) + + if "total_ribosome_counts" in plot_df.columns: + plots.append( + create_line_chart( + "total_ribosome_counts", "Total Ribosome Count", "Total ribosome count" + ) + ) + + if "total_ribosome_concentration_mM" in plot_df.columns: + plots.append( + create_line_chart( + "total_ribosome_concentration_mM", + "Total Ribosome Concentration", + "[Total ribosome] (mM)", + ) + ) + + if "active_ribosome_counts" in plot_df.columns: + plots.append( + create_line_chart( + "active_ribosome_counts", + "Active Ribosome Count", + "Active ribosome count", + skip_first_point=True, + ) + ) + + if "active_ribosome_concentration_mM" in plot_df.columns: + plots.append( + create_line_chart( + "active_ribosome_concentration_mM", + "Active Ribosome Concentration", + "[Active ribosome] (mM)", + skip_first_point=True, + ) + ) + + if "molar_fraction_active" in plot_df.columns: + plots.append( + create_line_chart( + "molar_fraction_active", + "Molar Fraction Active Ribosomes", + "Molar fraction active ribosomes", + skip_first_point=True, + ) + ) + + if "mass_fraction_active" in plot_df.columns: + plots.append( + create_line_chart( + "mass_fraction_active", + "Mass Fraction Active Ribosomes", + "Mass fraction active ribosomes", + skip_first_point=True, + ) + ) + + if "listeners__ribosome_data__did_initialize" in plot_df.columns: + plots.append( + create_line_chart( + "listeners__ribosome_data__did_initialize", + "Ribosome Activations", + "Activations per timestep", + ) + ) + + if "listeners__ribosome_data__did_terminate" in plot_df.columns: + plots.append( + create_line_chart( + "listeners__ribosome_data__did_terminate", + "Ribosome Deactivations", + "Deactivations per timestep", + ) + ) + + if "activations_per_volume" in plot_df.columns: + plots.append( + create_line_chart( + "activations_per_volume", + "Activations per Volume (fL)", + "Activations per Volume (fL)", + ) + ) + + if "deactivations_per_volume" in plot_df.columns: + plots.append( + create_line_chart( + "deactivations_per_volume", + "Deactivations per Volume (fL)", + "Deactivations per Volume (fL)", + ) + ) + + if "listeners__ribosome_data__actual_elongations" in plot_df.columns: + plots.append( + create_line_chart( + "listeners__ribosome_data__actual_elongations", + "Amino Acids Translated", + "AA translated", + ) + ) + + if "listeners__ribosome_data__effective_elongation_rate" in plot_df.columns: + plots.append( + create_line_chart( + "listeners__ribosome_data__effective_elongation_rate", + "Effective Ribosome Elongation Rate", + "Effective elongation rate", + ) + ) + + if not plots: + fallback_df = pl.DataFrame( + { + "message": ["No data available for ribosome usage visualization"], + "x": [0], + "y": [0], + } + ) + fallback_plot = ( + alt.Chart(fallback_df) + .mark_text(size=20, color="red") + .encode(x="x:Q", y="y:Q", text="message:N") + .properties( + width=600, + height=400, + title="Ribosome Usage Statistics - No Data Available", + ) + ) + plots.append(fallback_plot) + + # Arrange plots in 2 columns as in original + left_plots = plots[::2] # Even indices (0, 2, 4, ...) + right_plots = plots[1::2] # Odd indices (1, 3, 5, ...) + + # Ensure both columns have same length by adding empty chart if needed + if len(left_plots) > len(right_plots): + empty_chart = ( + alt.Chart(pl.DataFrame({"x": [0], "y": [0]})) + .mark_point(opacity=0) + .encode(x="x:Q", y="y:Q") + .properties(width=600, height=120) + ) + right_plots.append(empty_chart) + elif len(right_plots) > len(left_plots): + empty_chart = ( + alt.Chart(pl.DataFrame({"x": [0], "y": [0]})) + .mark_point(opacity=0) + .encode(x="x:Q", y="y:Q") + .properties(width=600, height=120) + ) + left_plots.append(empty_chart) + + # Create two column layout + left_column = alt.vconcat(*left_plots) + right_column = alt.vconcat(*right_plots) + combined_plot = ( + alt.hconcat(left_column, right_column) + .resolve_scale(x="shared", y="independent") + .properties(title="Ribosome Usage Statistics") + ) + + output_path = os.path.join(outdir, "ribosome_usage_report.html") + combined_plot.save(output_path) + print(f"Saved visualization to: {output_path}") + + return combined_plot diff --git a/ecoli/analysis/multigeneration/rna_decay_03_high.py b/ecoli/analysis/multigeneration/rna_decay_03_high.py new file mode 100644 index 000000000..758ac0f4f --- /dev/null +++ b/ecoli/analysis/multigeneration/rna_decay_03_high.py @@ -0,0 +1,177 @@ +""" +Plot dynamic traces of genes with high expression (> 20 counts of mRNA) + +EG10367_RNA 24.8 gapA Glyceraldehyde 3-phosphate dehydrogenase +EG11036_RNA 25.2 tufA Elongation factor Tu +EG50002_RNA 26.2 rpmA 50S Ribosomal subunit protein L27 +EG10671_RNA 30.1 ompF Outer membrane protein F +EG50003_RNA 38.7 acpP Apo-[acyl carrier protein] +EG10669_RNA 41.1 ompA Outer membrane protein A +EG10873_RNA 44.7 rplL 50S Ribosomal subunit protein L7/L12 dimer +EG12179_RNA 46.2 cspE Transcription antiterminator and regulator of RNA stability +EG10321_RNA 53.2 fliC Flagellin +EG10544_RNA 97.5 lpp Murein lipoprotein +""" + +import altair as alt +import os +from typing import Any +import pickle +import polars as pl +import numpy as np + +from duckdb import DuckDBPyConnection +from ecoli.library.parquet_emitter import ( + field_metadata, + open_arbitrary_sim_data, + named_idx, + read_stacked_columns, +) + + +def plot( + params: dict[str, Any], + conn: DuckDBPyConnection, + history_sql: str, + config_sql: str, + success_sql: str, + sim_data_dict: dict[str, dict[int, str]], + validation_data_paths: list[str], + outdir: str, + variant_metadata: dict[str, dict[int, Any]], + variant_names: dict[str, str], +): + """Plot dynamic traces of genes with high expression (> 20 counts of mRNA)""" + with open_arbitrary_sim_data(sim_data_dict) as f: + sim_data = pickle.load(f) + cistron_array = sim_data.process.transcription.cistron_data.struct_array + all_ids = list(cistron_array["id"]) + deg_rates = {row["id"]: row["deg_rate"] for row in cistron_array} + + # Define high-expression cistrons + target_ids = [ + "EG10367_RNA", + "EG11036_RNA", + "EG50002_RNA", + "EG10671_RNA", + "EG50003_RNA", + "EG10669_RNA", + "EG10873_RNA", + "EG12179_RNA", + "EG10321_RNA", + "EG10544_RNA", + ] + valid_ids = [cid for cid in target_ids if cid in all_ids] + if not valid_ids: + print("[ERROR] No matching cistrons in sim_data") + return + + # Retrieve metadata for degradation and counts + deg_field = "listeners__rna_degradation_listener__count_RNA_degraded_per_cistron" + cnt_field = "listeners__rna_counts__mRNA_cistron_counts" + try: + deg_meta = field_metadata(conn, config_sql, deg_field) + cnt_meta = field_metadata(conn, config_sql, cnt_field) + except Exception as e: + print(f"[ERROR] field_metadata failed: {e}") + return + + # Find indices for valid cistrons + deg_indices = [deg_meta.index(cid) for cid in valid_ids] + cnt_indices = [cnt_meta.index(cid) for cid in valid_ids] + + # Build named_idx structures + deg_named = named_idx(deg_field, valid_ids, [deg_indices]) + cnt_named = named_idx(cnt_field, [f"{i}_cnt" for i in valid_ids], [cnt_indices]) + + # Read stacked columns + try: + data_dict = read_stacked_columns( + history_sql, + [deg_named, cnt_named], + conn=conn, + ) + except Exception as e: + print(f"[ERROR] read_stacked_columns failed: {e}") + return + + # Convert to Polars DataFrame + df = pl.DataFrame(data_dict) + # convert to minutes + if "time" in df.columns: + df = df.with_columns((pl.col("time") / 60).alias("time_min")) + + # Melt degradation and counts + deg_cols = valid_ids + cnt_cols = [f"{i}_cnt" for i in valid_ids] + deg_df = df.select(["time_min"] + deg_cols).melt( + "time_min", variable_name="cistron", value_name="degraded" + ) + cnt_df = ( + df.select(["time_min"] + cnt_cols) + .melt("time_min", variable_name="cistron", value_name="counts") + .with_columns(pl.col("cistron").str.replace("_cnt", "", literal=True)) + ) + joined = deg_df.join(cnt_df, on=["time_min", "cistron"]) + + # Smooth and fit per cistron + charts = [] + window = 100 + for cid in valid_ids[:9]: + sub = joined.filter(pl.col("cistron") == cid).sort("time_min") + if sub.height < 2 * window: + continue + counts = sub["counts"].to_numpy() + degraded = sub["degraded"].to_numpy() + # smoothing + smooth_c = np.convolve(counts, np.ones(window) / window, mode="same") + dt = np.gradient(sub["time_min"].to_numpy() * 60) + rate = degraded / np.maximum(dt, 1e-10) + smooth_r = np.convolve(rate, np.ones(window) / window, mode="same") + mask = ( + np.isfinite(smooth_c) + & (smooth_c > 0) + & np.isfinite(smooth_r) + & (smooth_r >= 0) + ) + A = smooth_c[mask] + y = smooth_r[mask] + if len(A) < 10: + continue + kdeg = np.linalg.lstsq(A[:, None], y, rcond=None)[0][0] + + # Prepare data for plotting + plot_df = pl.DataFrame({"RNA_counts": A, "RNA_degraded": y}) + # Regression line data + line_x = np.linspace(A.min(), A.max(), 100) + line_y = kdeg * line_x + + # Scatter + scatter = ( + alt.Chart(plot_df) + .mark_circle(size=20, opacity=0.6, color="blue") + .encode(x="RNA_counts:Q", y="RNA_degraded:Q") + ) + # Regression line + line = ( + alt.Chart(pl.DataFrame({"RNA_counts": line_x, "RNA_degraded": line_y})) + .mark_line(color="red", strokeWidth=0.5) + .encode(x="RNA_counts:Q", y="RNA_degraded:Q") + ) + + # Combine and style + title = f"{cid} kdeg meas: {kdeg:.1e} s⁻¹ | kdeg exp: {deg_rates[cid]:.1e} s⁻¹" + charts.append((scatter + line).properties(title=title, width=250, height=200)) + + if charts: + rows = [alt.hconcat(*charts[i : i + 3]) for i in range(0, len(charts), 3)] + combined = alt.vconcat(*rows).properties( + title="RNA Decay - High Expression Genes" + ) + output = os.path.join(outdir, "rna_decay_03_high.html") + combined.save(output) + print(f"[INFO] Saved visualization to: {output}") + return combined + else: + print("[ERROR] No charts generated") + return None diff --git a/ecoli/analysis/multivariant/catalyst_count.py b/ecoli/analysis/multivariant/catalyst_count.py new file mode 100644 index 000000000..d514de6ef --- /dev/null +++ b/ecoli/analysis/multivariant/catalyst_count.py @@ -0,0 +1,528 @@ +""" +Visualize catalyst counts over time for specified BioCyc reactions across generations. +For each specific BioCyc ID reaction, this scripts will add all the catalysts which catalyse it: +```number of catalysts = sum(number of catalysts[i])``` + +Supports two visualization modes: +1. 'grid' mode: Each row represents a variant, each column represents a reaction's catalysts +2. 'stacked' mode: Each reaction's catalysts get their own chart, variants shown as different colored lines + +You can specify the reactions and layout using parameters: + "catalyst_count": { + # Required: specify BioCyc reaction IDs to visualize + "BioCyc_ID": ["Name1", "Name2", ...], + # Optional: specify variants to visualize + # If not specified, all variants will be used + "variant": [1, 2, ...], + # Optional: specify generations to visualize + # If not specified, all generations will be used + "generation": [1, 2, ...], + # Optional: specify layout mode ('grid' or 'stacked') + # Default: 'stacked' + "layout": "stacked" # or "grid" + } + +This script uses SQL to efficiently calculate catalyst counts directly in the database, +reducing memory usage and improving performance. +""" + +import altair as alt +import os +from typing import Any +import pickle + +import polars as pl +from duckdb import DuckDBPyConnection +import pandas as pd + +from ecoli.library.parquet_emitter import open_arbitrary_sim_data, field_metadata +from ecoli.analysis.utils import create_base_to_extended_mapping + + +def plot( + params: dict[str, Any], + conn: DuckDBPyConnection, + history_sql: str, + config_sql: str, + success_sql: str, + sim_data_dict: dict[str, dict[int, str]], + validation_data_paths: list[str], + outdir: str, + variant_metadata: dict[str, dict[int, Any]], + variant_names: dict[str, str], +): + """Visualize catalyst counts over time for specified BioCyc reactions across generations.""" + + # Get parameters + biocyc_ids = params.get("BioCyc_ID", []) + if not biocyc_ids: + print( + "[ERROR] No BioCyc_ID found in params. Please specify reaction IDs to visualize." + ) + return None + + if isinstance(biocyc_ids, str): + biocyc_ids = [biocyc_ids] + + # Get layout mode (default to 'stacked') + layout_mode = params.get("layout", "stacked").lower() + if layout_mode not in ["grid", "stacked"]: + print(f"[WARNING] Invalid layout mode '{layout_mode}'. Using 'stacked' mode.") + layout_mode = "stacked" + + print( + f"[INFO] Visualizing catalyst counts for {len(biocyc_ids)} reactions: {biocyc_ids}" + ) + print(f"[INFO] Using layout mode: {layout_mode}") + + # Load sim_data to get reaction_to_catalyst mapping + try: + with open_arbitrary_sim_data(sim_data_dict) as f: + sim_data = pickle.load(f) + reaction_to_catalyst = sim_data.process.metabolism.reaction_catalysts + print( + f"[INFO] Loaded reaction to catalyst mapping with {len(reaction_to_catalyst)} reactions" + ) + except Exception as e: + print(f"[ERROR] Error loading sim_data: {e}") + return None + + # Create base to extended reaction mapping + base_to_extended_mapping = create_base_to_extended_mapping(sim_data_dict) + if not base_to_extended_mapping: + print("[ERROR] Could not create base to extended reaction mapping") + return None + + # Load catalyst IDs from config + try: + catalyst_ids = field_metadata( + conn, config_sql, "listeners__fba_results__catalyst_counts" + ) + print(f"[INFO] Total catalysts in sim_data: {len(catalyst_ids)}") + except Exception as e: + print(f"[ERROR] Error loading catalyst IDs: {e}") + return None + + # Build catalyst calculation SQL for efficient processing + catalyst_calculation_sql, valid_biocyc_ids = build_catalyst_calculation_sql( + biocyc_ids, + base_to_extended_mapping, + reaction_to_catalyst, + catalyst_ids, + history_sql, + ) + + if not catalyst_calculation_sql or not valid_biocyc_ids: + print("[ERROR] Could not build catalyst calculation SQL") + return None + + print(f"[INFO] Processing {len(valid_biocyc_ids)} valid BioCyc IDs") + + # Execute the optimized SQL query + try: + df = conn.sql(catalyst_calculation_sql).pl() + print(f"[INFO] Loaded data with {df.height} time steps") + except Exception as e: + print(f"[ERROR] Error executing catalyst calculation SQL: {e}") + return None + + if df.is_empty(): + print("[ERROR] No data found") + return None + + # Filter by specified variants and generations if provided + target_variants = params.get("variant", None) + target_generations = params.get("generation", None) + + if target_variants is not None: + print(f"[INFO] Filtering for variants: {target_variants}") + df = df.filter(pl.col("variant").is_in(target_variants)) + + if target_generations is not None: + print(f"[INFO] Filtering for generations: {target_generations}") + df = df.filter(pl.col("generation").is_in(target_generations)) + + # Print variant and generation information + unique_variants = sorted(df["variant"].unique().to_list()) + unique_generations = sorted(df["generation"].unique().to_list()) + print(f"[INFO] Found {len(unique_variants)} variants: {unique_variants}") + print(f"[INFO] Found {len(unique_generations)} generations: {unique_generations}") + + # Calculate average catalyst counts based on layout mode + if layout_mode == "grid": + # For grid mode: calculate averages by variant, generation, and reaction + avg_data = [] + for biocyc_id in valid_biocyc_ids: + catalyst_col = f"{biocyc_id}_catalyst_count" + if catalyst_col in df.columns: + variant_gen_avgs = df.group_by(["variant", "generation"]).agg( + pl.col(catalyst_col).mean().alias("avg_catalyst_count") + ) + + for row in variant_gen_avgs.iter_rows(named=True): + avg_data.append( + { + "biocyc_id": biocyc_id, + "variant": row["variant"], + "generation": row["generation"], + "avg_catalyst_count": row["avg_catalyst_count"], + } + ) + + avg_df = pl.DataFrame(avg_data) + else: + # For stacked mode: calculate averages by variant and reaction + avg_catalyst_counts = {} + for biocyc_id in valid_biocyc_ids: + catalyst_col = f"{biocyc_id}_catalyst_count" + if catalyst_col in df.columns: + variant_avgs = df.group_by("variant").agg( + pl.col(catalyst_col).mean().alias("avg_catalyst_count") + ) + avg_catalyst_counts[biocyc_id] = variant_avgs + + # Create visualization based on layout mode + if layout_mode == "grid": + combined_plot = create_grid_visualization( + df, avg_df, valid_biocyc_ids, unique_variants, unique_generations + ) + output_filename = "catalyst_count_grid_analysis.html" + title_suffix = ( + f"{len(unique_variants)} Variants × {len(valid_biocyc_ids)} Reactions" + ) + else: + combined_plot = create_stacked_visualization( + df, avg_catalyst_counts, valid_biocyc_ids + ) + output_filename = "catalyst_count_stacked_analysis.html" + title_suffix = f"Multi-Variant Analysis ({len(valid_biocyc_ids)} reactions)" + + if combined_plot is None: + print("[ERROR] Failed to create visualization") + return None + + # Add overall title + combined_plot = combined_plot.properties( + title=alt.TitleParams( + text=f"Catalyst Count Analysis: {title_suffix}", + fontSize=16, + anchor="start", + ) + ).resolve_scale(color="shared") + + # Save the plot + output_path = os.path.join(outdir, output_filename) + combined_plot.save(output_path) + print(f"[INFO] Saved visualization to: {output_path}") + + return combined_plot + + +def build_catalyst_calculation_sql( + biocyc_ids, + base_to_extended_mapping, + reaction_to_catalyst, + catalyst_ids, + history_sql, +): + """Build SQL query to efficiently calculate catalyst counts for specified BioCyc reactions.""" + + # Find catalysts for each BioCyc ID and build SQL columns + biocyc_to_catalysts = {} + valid_biocyc_ids = [] + catalyst_calculations = [] + + for biocyc_id in biocyc_ids: + catalysts = set() + + # Get extended reaction IDs for this BioCyc ID + extended_ids = base_to_extended_mapping.get(biocyc_id, []) + + if not extended_ids: + print( + f"[WARNING] No extended reaction IDs found for BioCyc ID: {biocyc_id}" + ) + continue + + # Find catalysts for all extended reactions + for ext_id in extended_ids: + if ext_id in reaction_to_catalyst: + reaction_catalysts = reaction_to_catalyst[ext_id] + catalysts.update(reaction_catalysts) + + if catalysts: + # Convert catalyst IDs to indices in the catalyst_ids array + catalyst_indices = [] + for cat_id in catalysts: + try: + idx = catalyst_ids.index(cat_id) + catalyst_indices.append(idx) + except ValueError: + print( + f"[WARNING] Catalyst {cat_id} not found in catalyst_ids array" + ) + + if catalyst_indices: + biocyc_to_catalysts[biocyc_id] = { + "catalyst_ids": list(catalysts), + "catalyst_indices": catalyst_indices, + } + valid_biocyc_ids.append(biocyc_id) + + # Build SQL calculation for this BioCyc ID + # Convert 0-based indices to 1-based for DuckDB SQL + sql_indices = [str(idx + 1) for idx in catalyst_indices] + catalyst_sum = " + ".join([f"catalysts[{idx}]" for idx in sql_indices]) + # Use quotes around column name to handle special characters like hyphens + catalyst_calculations.append( + f'({catalyst_sum}) AS "{biocyc_id}_catalyst_count"' + ) + + print( + f"[INFO] Found {len(catalyst_indices)} catalysts for {biocyc_id}: {list(catalysts)}" + ) + else: + print(f"[WARNING] No valid catalyst indices found for {biocyc_id}") + else: + print(f"[WARNING] No catalysts found for BioCyc ID: {biocyc_id}") + + if not valid_biocyc_ids or not catalyst_calculations: + print("[ERROR] No valid BioCyc IDs with catalysts found") + return None, [] + + # Build the complete SQL query + catalyst_calculations_str = ",\n ".join(catalyst_calculations) + + sql = f""" + WITH renamed AS ( + SELECT + time / 60.0 AS time_min, + generation, + variant, + listeners__fba_results__catalyst_counts AS catalysts + FROM ({history_sql}) + ) + SELECT + time_min, + generation, + variant, + {catalyst_calculations_str} + FROM renamed + ORDER BY variant, generation, time_min + """ + + print(f"[INFO] Built SQL with {len(catalyst_calculations)} catalyst calculations") + return sql, valid_biocyc_ids + + +def create_grid_visualization( + df, avg_df, valid_biocyc_ids, unique_variants, unique_generations +): + """Create grid layout visualization (rows = variants, columns = reactions).""" + + def create_subplot_chart(variant, biocyc_id): + """Create a single subplot for a specific variant-reaction combination.""" + catalyst_col = f"{biocyc_id}_catalyst_count" + + # Check if the column exists in dataframe + if catalyst_col not in df.columns: + print(f"[WARNING] Column {catalyst_col} not found in dataframe") + return None + + # Filter data for this variant and reaction + subplot_data = ( + df.filter(pl.col("variant") == variant) + .select(["time_min", "generation", catalyst_col]) + .filter(pl.col(catalyst_col).is_not_null()) + ) + + if subplot_data.height == 0: + print(f"[WARNING] No data for variant {variant}, reaction {biocyc_id}") + return None + + # Main line chart with generations as different colors + line_chart = ( + alt.Chart(subplot_data) + .mark_line(strokeWidth=1.5) + .encode( + x=alt.X( + "time_min:Q", + title="Time (min)" if variant == unique_variants[-1] else "", + ), + y=alt.Y( + f"{catalyst_col}:Q", + title="Total Catalyst Count" + if biocyc_id == valid_biocyc_ids[0] + else "", + ), + color=alt.Color( + "generation:N", + legend=alt.Legend(title="Generation") + if variant == unique_variants[0] + and biocyc_id == valid_biocyc_ids[0] + else None, + ), + tooltip=["time_min:Q", f"{catalyst_col}:Q", "generation:N"], + ) + ) + + # For grid mode, we don't show average lines to keep the visualization clean + # Combine all elements + combined = line_chart.resolve_scale(color="shared") + + # Add title only for top row + if variant == unique_variants[0]: + combined = combined.properties(title=f"{biocyc_id}") + + combined = combined.properties(width=400, height=300) + + return combined + + def create_empty_subplot(): + """Create an empty placeholder subplot.""" + return ( + alt.Chart(pd.DataFrame({"x": [0], "y": [0]})) + .mark_point(opacity=0) + .properties(width=200, height=150) + ) + + # Create subplot grid: rows = variants, columns = reactions + subplot_grid = [] + + for variant in unique_variants: + variant_row = [] + for biocyc_id in valid_biocyc_ids: + subplot = create_subplot_chart(variant, biocyc_id) + if subplot is not None: + variant_row.append(subplot) + else: + # Create empty placeholder if no data + variant_row.append(create_empty_subplot()) + + if variant_row: + # Add variant label on the left + variant_label = ( + alt.Chart(pd.DataFrame({"label": [f"Variant {variant}"]})) + .mark_text( + align="center", baseline="middle", fontSize=12, fontWeight="bold" + ) + .encode(text="label:N") + .properties(width=160, height=300) + ) + + # Combine variant label with row of subplots + row_with_label = alt.hconcat(variant_label, *variant_row, spacing=10) + subplot_grid.append(row_with_label) + + if not subplot_grid: + print("[ERROR] No valid subplots could be created") + return None + + # Combine all rows + combined_plot = alt.vconcat(*subplot_grid, spacing=20) + return combined_plot + + +def create_stacked_visualization(df, avg_catalyst_counts, valid_biocyc_ids): + """Create stacked layout visualization (one chart per reaction, variants as colored lines).""" + + def create_catalyst_count_chart(biocyc_id, catalyst_col, variant_avgs): + """Create a line chart for a single reaction's catalyst counts with average lines for each variant.""" + + # Check if the column exists in dataframe + if catalyst_col not in df.columns: + print(f"[WARNING] Column {catalyst_col} not found in dataframe") + return None + + # Select only the columns we need to minimize data transfer + data = df.select(["time_min", "generation", "variant", catalyst_col]) + + # Remove any null values + data = data.filter(pl.col(catalyst_col).is_not_null()) + + if data.height == 0: + print(f"[WARNING] No valid data for reaction {biocyc_id}") + return None + + # Main catalyst count line chart (different variants as different colored lines) + catalyst_chart = ( + alt.Chart(data) + .mark_line(strokeWidth=2) + .encode( + x=alt.X("time_min:Q", title="Time (min)"), + y=alt.Y(f"{catalyst_col}:Q", title="Total Catalyst Count"), + color=alt.Color("variant:N", legend=alt.Legend(title="Variant")), + tooltip=[ + "time_min:Q", + f"{catalyst_col}:Q", + "variant:N", + "generation:N", + ], + ) + ) + + # Create average lines for each variant + avg_line_data = [] + for row in variant_avgs.iter_rows(named=True): + variant_name = row["variant"] + avg_value = row["avg_catalyst_count"] + avg_line_data.append( + { + "variant": variant_name, + "avg_catalyst_count": avg_value, + "label": f"{variant_name} Avg: {avg_value:.2f}", + } + ) + + if avg_line_data: + avg_line_df = pd.DataFrame(avg_line_data) + + avg_lines = ( + alt.Chart(avg_line_df) + .mark_rule(strokeDash=[5, 5], strokeWidth=2) + .encode( + y=alt.Y("avg_catalyst_count:Q"), + color=alt.Color( + "variant:N", legend=None + ), # Use same color scale as main chart + tooltip=["label:N"], + ) + ) + else: + avg_lines = alt.Chart().mark_point() # Empty chart + + # Combine catalyst count line and average lines + combined_chart = ( + (catalyst_chart + avg_lines) + .properties( + title=f"Catalyst Count vs Time: {biocyc_id}", width=600, height=300 + ) + .resolve_scale(y="shared", color="shared") + ) + + return combined_chart + + # Create charts for each reaction + charts = [] + + for biocyc_id in valid_biocyc_ids: + catalyst_col = f"{biocyc_id}_catalyst_count" + variant_avgs = avg_catalyst_counts.get(biocyc_id) + if variant_avgs is not None: + chart = create_catalyst_count_chart(biocyc_id, catalyst_col, variant_avgs) + if chart is not None: + charts.append(chart) + + if not charts: + print("[ERROR] No valid charts could be created") + return None + + # Arrange charts vertically + if len(charts) == 1: + combined_plot = charts[0] + else: + combined_plot = alt.vconcat(*charts).resolve_scale( + x="shared", y="independent", color="shared" + ) + + return combined_plot diff --git a/ecoli/analysis/multivariant/cell_growth_rate.py b/ecoli/analysis/multivariant/cell_growth_rate.py new file mode 100644 index 000000000..4d2c536fe --- /dev/null +++ b/ecoli/analysis/multivariant/cell_growth_rate.py @@ -0,0 +1,436 @@ +""" +Plot cell growth rate (1/hour) over time for multivariant simulation in vEcoli, and: +0. you can specify variants and generations to analyze, like: + \"multivariant\": { + ...... + \"cell_growth_rate\": { + # Optional: specify variants and generations to visualize + # If not specified, all will be used + \"variant\": [0, 1, ...], + \"generation\": [1, 2, ....], + \"show_reference\": true/false + } + ...... + } +1. each variant has its own plot; +2. at each subplot, time is divided by generation id; + +It can also be used at multigeneration analysis. +""" + +import os +from typing import Any +import altair as alt +import pickle +import polars as pl +import numpy as np +from duckdb import DuckDBPyConnection +import pandas as pd + +from ecoli.library.parquet_emitter import open_arbitrary_sim_data + + +# ------------------------------------- # +def calculate_average_growth_rates(df, variant_names, group_by_generation=False): + """ + Calculate average cell growth rate for each variant, optionally grouped by generation. + + Args: + df: Polars DataFrame with processed growth rate data + variant_names: Dictionary mapping variant IDs to names + group_by_generation: If True, group by both variant and generation; + if False, group by variant only + + Returns: + Polars DataFrame with average growth rates + """ + + # Determine grouping columns + group_cols = ["variant"] + sort_cols = ["variant"] + + if group_by_generation: + group_cols.append("generation") + sort_cols.append("generation") + + # Calculate average growth rate + avg_growth_df = ( + df.filter(pl.col("growth_rate_per_hour").is_not_null()) + .group_by(group_cols) + .agg( + [ + pl.col("growth_rate_per_hour").mean().alias("avg_growth_rate"), + pl.col("growth_rate_per_hour").std().alias("std_growth_rate"), + pl.col("growth_rate_per_hour").count().alias("data_points"), + pl.col("growth_rate_per_hour").min().alias("min_growth_rate"), + pl.col("growth_rate_per_hour").max().alias("max_growth_rate"), + ] + ) + .sort(sort_cols) + ) + + # Add variant names + avg_growth_df = avg_growth_df.with_columns( + pl.col("variant") + .map_elements( + lambda x: variant_names.get(str(x), f"Variant {x}"), return_dtype=pl.Utf8 + ) + .alias("variant_name") + ) + + return avg_growth_df + + +def create_growth_rate_comparison_chart( + avg_by_variant, avg_by_variant_gen, ref_growth_rate, show_reference=True +): + """Create charts comparing average growth rates across variants and generations.""" + + # Chart 1: Average by variant only + avg_variant_df = avg_by_variant.to_pandas() + + variant_bars = ( + alt.Chart(avg_variant_df) + .mark_bar() + .encode( + x=alt.X("variant_name:N", title="Variant"), + y=alt.Y("avg_growth_rate:Q", title="Average Growth Rate (1/hour)"), + color=alt.Color( + "variant_name:N", legend=None, scale=alt.Scale(scheme="category10") + ), + tooltip=[ + "variant_name:N", + "avg_growth_rate:Q", + "std_growth_rate:Q", + "data_points:Q", + ], + ) + ) + + variant_error_bars = ( + alt.Chart(avg_variant_df) + .mark_errorbar() + .encode(x="variant_name:N", y="avg_growth_rate:Q", yError="std_growth_rate:Q") + ) + + variant_chart_layers = [variant_bars, variant_error_bars] + if show_reference: + variant_ref_line = ( + alt.Chart(pd.DataFrame({"ref_rate": [ref_growth_rate]})) + .mark_rule(strokeDash=[5, 5], strokeWidth=2, color="red") + .encode( + y="ref_rate:Q", + tooltip=alt.value(f"Reference: {ref_growth_rate:.3f} /hour"), + ) + ) + variant_chart_layers.append(variant_ref_line) + + variant_chart = ( + alt.layer(*variant_chart_layers) + .properties(title="Average Growth Rate by Variant", width=350, height=300) + .resolve_scale(color="independent") + ) + + # Chart 2: Average by variant and generation + avg_gen_df = avg_by_variant_gen.to_pandas() + + # Create a combined label for x-axis + avg_gen_df["variant_gen_label"] = ( + avg_gen_df["variant_name"] + " G" + avg_gen_df["generation"].astype(str) + ) + + gen_bars = ( + alt.Chart(avg_gen_df) + .mark_bar() + .encode( + x=alt.X( + "variant_gen_label:N", + title="Variant - Generation", + sort=alt.Sort(["variant", "generation"]), + ), + y=alt.Y("avg_growth_rate:Q", title="Average Growth Rate (1/hour)"), + color=alt.Color( + "variant_name:N", + legend=alt.Legend(title="Variant"), + scale=alt.Scale(scheme="category10"), + ), + tooltip=[ + "variant_name:N", + "generation:O", + "avg_growth_rate:Q", + "std_growth_rate:Q", + "data_points:Q", + ], + ) + ) + + gen_error_bars = ( + alt.Chart(avg_gen_df) + .mark_errorbar() + .encode( + x="variant_gen_label:N", y="avg_growth_rate:Q", yError="std_growth_rate:Q" + ) + ) + + # Conditional reference line for generation chart + generation_chart_layers = [gen_bars, gen_error_bars] + if show_reference: + gen_ref_line = ( + alt.Chart( + pd.DataFrame( + { + "ref_rate": [ref_growth_rate], + "legend_label": ["Reference Simulation Growth Rate"], + } + ) + ) + .mark_rule(strokeDash=[5, 5], strokeWidth=2) + .encode( + y="ref_rate:Q", + color=alt.Color( + "legend_label:N", + scale=alt.Scale(range=["red"]), + legend=alt.Legend(title="Reference"), + ), + tooltip=alt.value(f"Reference: {ref_growth_rate:.3f} /hour"), + ) + ) + generation_chart_layers.append(gen_ref_line) + + generation_chart = ( + alt.layer(*generation_chart_layers) + .properties( + title="Average Growth Rate by Variant and Generation", width=500, height=300 + ) + .resolve_scale(color="independent") + ) + + # Combine both charts horizontally + combined_chart = ( + alt.hconcat(variant_chart, generation_chart) + .resolve_scale(color="shared", y="shared") + .properties(title="Cell Growth Rate Analysis - Comprehensive Comparison") + ) + + return combined_chart + + +# ------------------------------------- # + + +def plot( + params: dict[str, Any], + conn: DuckDBPyConnection, + history_sql: str, + config_sql: str, + success_sql: str, + sim_data_dict: dict[str, dict[int, str]], + validation_data_paths: list[str], + outdir: str, + variant_metadata: dict[str, dict[int, Any]], + variant_names: dict[str, str], +): + """Visualize cell growth rate metrics for multivariant E. coli simulation.""" + + # Load simulation data to get reference doubling time + with open_arbitrary_sim_data(sim_data_dict) as f: + sim_data = pickle.load(f) + + # Reference line for expected growth rate + sim_doubling_time = sim_data.doubling_time.asNumber() + ref_growth_rate = float(np.log(2)) / sim_doubling_time * 60 # Convert to 1/hour + + # Required columns for analysis + required_columns = [ + "time", + "variant", + "generation", + "listeners__mass__instantaneous_growth_rate", + ] + + # Build SQL query + all_columns = ", ".join(required_columns) + sql = f""" + SELECT {all_columns} + FROM ({history_sql}) + ORDER BY variant, generation, time + """ + + df = conn.sql(sql).pl() + + # Configuration parameters for filtering + target_variants = params.get("variant", None) # List of variant IDs or None for all + target_generations = params.get( + "generation", None + ) # List of generation IDs or None for all + show_reference = params.get( + "show_reference", True + ) # Whether to show reference line + print(f"[INFO] Show reference line: {show_reference}") + + # Filter by specified variants and generations + if target_variants is not None: + print(f"[INFO] Target variants: {target_variants}") + df = df.filter(pl.col("variant").is_in(target_variants)) + if target_generations is not None: + print(f"[INFO] Target generations: {target_generations}") + df = df.filter(pl.col("generation").is_in(target_generations)) + + # Time processing + df = df.with_columns((pl.col("time") / 60).alias("time_min")) + + # Calculate cell doubling time from instantaneous growth rate + if "listeners__mass__instantaneous_growth_rate" in df.columns: + # Convert from 1/sec to 1/hour + growth_rate = pl.col("listeners__mass__instantaneous_growth_rate") * 3600 + + # Sanitize doubling time values + growth_rate_valid_condition = ( + (growth_rate > 0) + & growth_rate.is_finite() + & (growth_rate < 2 * ref_growth_rate) + ) + + df = df.with_columns( + pl.when(growth_rate_valid_condition) + .then(growth_rate) + .otherwise(None) + .alias("growth_rate_per_hour") + ) + + # Specify variants for subplot creation + unique_variants = df["variant"].unique().sort().to_list() + + if not unique_variants: + # Create fallback chart if no data + fallback_df = pd.DataFrame( + {"message": ["No data available"], "x": [0], "y": [0]} + ) + fallback_chart = ( + alt.Chart(fallback_df) + .mark_text(size=20, color="red") + .encode(x="x:Q", y="y:Q", text="message:N") + .properties(width=600, height=400, title="No Data Available") + ) + out_path = os.path.join(outdir, "cell_growth_rate_report.html") + fallback_chart.save(out_path) + print(f"[ERROR] No data available. Saved fallback to: {out_path}") + return fallback_chart + + # Create subplot for each variant + charts = [] + + for variant_id in unique_variants: + variant_df = df.filter(pl.col("variant") == variant_id) + + if variant_df.height == 0: + continue + + # Get variant name for title + variant_name = variant_names.get(str(variant_id), f"Variant {variant_id}") + + # Create line chart for growth rate over time + base_chart = alt.Chart(variant_df) + + # Growth rate line + growth_line = base_chart.mark_line(point=True, strokeWidth=2).encode( + x=alt.X("time_min:Q", title="Time (min)"), + y=alt.Y("growth_rate_per_hour:Q", title="Growth Rate (1/hour)"), + color=alt.Color( + "generation:N", + legend=alt.Legend(title="Generation"), + scale=alt.Scale(scheme="category10"), + ), + tooltip=["time_min:Q", "growth_rate_per_hour:Q", "generation:N"], + ) + # Create chart layers list + chart_layers = [growth_line] + + # Conditionally add reference growth rate line + if show_reference: + ref_line = ( + alt.Chart( + pd.DataFrame( + { + "ref_rate": [ref_growth_rate], + "legend_label": ["Reference Simulation Growth Rate"], + } + ) + ) + .mark_rule(strokeDash=[5, 5], strokeWidth=2) + .encode( + y="ref_rate:Q", + color=alt.Color( + "legend_label:N", + scale=alt.Scale(range=["red"]), + legend=alt.Legend(title="Reference"), + ), + tooltip=alt.value(f"Reference rate: {ref_growth_rate:.3f} /hour"), + ) + ) + chart_layers.append(ref_line) + + # Combine layers + variant_chart = ( + alt.layer(*chart_layers) + .properties( + title=f"{variant_name} - Cell Growth Rate", width=500, height=300 + ) + .resolve_scale(color="independent", y="shared") + ) + + charts.append(variant_chart) + + # Arrange charts in a grid layout + if len(charts) == 1: + combined_chart = charts[0] + elif len(charts) == 2: + combined_chart = alt.hconcat(*charts) + else: + # For more than 2 charts, arrange in rows of 2 + rows = [] + for i in range(0, len(charts), 2): + if i + 1 < len(charts): + rows.append(alt.hconcat(charts[i], charts[i + 1])) + else: + rows.append(charts[i]) + combined_chart = alt.vconcat(*rows) + + # Add overall title for multiple plkots + if len(charts) > 1: + final_chart = combined_chart.resolve_scale( + x="shared", y="independent", color="independent" + ).properties(title="Cell Growth Rate Analysis - Multivariant Comparison") + else: + final_chart = combined_chart + + # Save the visualization + out_path = os.path.join(outdir, "multivariant_cell_growth_rate_report.html") + final_chart.save(out_path) + print(f"Saved cell growth rate visualization to: {out_path}") + + # Calculate averages + avg_by_variant = calculate_average_growth_rates( + df, variant_names, group_by_generation=False + ) + avg_by_variant_gen = calculate_average_growth_rates( + df, variant_names, group_by_generation=True + ) + + # Optional: Save results to CSV + avg_by_variant.write_csv( + os.path.join(outdir, "average_growth_rates_by_variant.csv") + ) + avg_by_variant_gen.write_csv( + os.path.join(outdir, "average_growth_rates_by_variant_generation.csv") + ) + + # Create and save comprehensive comparison chart + comprehensive_chart = create_growth_rate_comparison_chart( + avg_by_variant, avg_by_variant_gen, ref_growth_rate, show_reference + ) + comparison_path = os.path.join(outdir, "average_growth_rate_comparison.html") + comprehensive_chart.save(comparison_path) + print(f"Saved comprehensive growth rate comparison to: {comparison_path}") + + return final_chart diff --git a/ecoli/analysis/multivariant/cell_mass.py b/ecoli/analysis/multivariant/cell_mass.py new file mode 100644 index 000000000..3939adbbf --- /dev/null +++ b/ecoli/analysis/multivariant/cell_mass.py @@ -0,0 +1,175 @@ +""" +Plot absolue / normalized cell mass over time for multivariant simulation in vEcoli, and: +0. you can specify variants and generations to analyze, like: + \"multivariant\": { + ...... + \"cell_mass\": { + # Optional: specify variants and generations to visualize + # If not specified, all will be used + \"variant\": [0, 1, ...], + \"generation\": [1, 2, ....] + } + ...... + } +1. each variant has its own plot; +2. at each subplot, time is divided by generation id; + +It can also be used at multigeneration analysis. +""" + +import os +from typing import Any +import altair as alt +import polars as pl +import pandas as pd +from duckdb import DuckDBPyConnection + + +def plot( + params: dict[str, Any], + conn: DuckDBPyConnection, + history_sql: str, + config_sql: str, + success_sql: str, + sim_data_dict: dict[str, dict[int, str]], + validation_data_paths: list[str], + outdir: str, + variant_metadata: dict[str, dict[int, Any]], + variant_names: dict[str, str], +): + # Load data with required columns + required_columns = [ + "time", + "variant", + "lineage_seed", + "generation", + "listeners__mass__dry_mass", + "listeners__mass__dry_mass_fold_change", + ] + + sql = f""" + SELECT {", ".join(required_columns)} + FROM ({history_sql}) + ORDER BY variant, lineage_seed, generation, time + """ + + df = conn.sql(sql).pl() + + # Process time + df = df.with_columns( + [ + (pl.col("time") / 60).alias("time_min"), + ] + ) + + # Configuration parameters for filtering + target_variants = params.get("variant", None) # List of variant IDs or None for all + target_generations = params.get( + "generation", None + ) # List of generation IDs or None for all + + # Filter by specified variants and generations + if target_variants is not None: + print(f"[INFO] Target variants: {target_variants}") + df = df.filter(pl.col("variant").is_in(target_variants)) + if target_generations is not None: + print(f"[INFO] Target generations: {target_generations}") + df = df.filter(pl.col("generation").is_in(target_generations)) + + # Get variants and create plots + variants = df.select("variant").unique().to_series().to_list() + + # ----------------------------------------# + plots = [] + + # Create subplot for each variant + for variant in variants: + variant_df = df.filter(pl.col("variant") == variant).to_pandas() + variant_name = variant_names.get(variant, f"Variant {variant}") + + # Create base chart with line plots only + base = alt.Chart(variant_df).add_selection( + alt.selection_interval(bind="scales") + ) + + # Base encoding + tooltip_fields: list[str] = ["time_min:Q", "generation:N"] + base_encode = { + "x": alt.X("time_min:Q", title="Time (min)", scale=alt.Scale(nice=False)), + # Different generations with different colors + # Within same generation, color is the same + "color": alt.Color( + "generation:N", + legend=alt.Legend(title="Generation"), + scale=alt.Scale(scheme="category10"), + ), + } + + # Absolute dry mass plot + mass_plot = ( + base.mark_line(strokeWidth=2.5) + .encode( + x=base_encode["x"], + color=base_encode["color"], + tooltip=tooltip_fields + ["listeners__mass__dry_mass:Q"], + detail="lineage_seed:N", + y=alt.Y( + "listeners__mass__dry_mass:Q", + title="Dry Mass (fg)", + scale=alt.Scale(nice=False), + ), + ) + .properties( + width=400, height=200, title=f"{variant_name} - Absolute Dry Mass" + ) + ) + + # Normalized dry mass plot + norm_mass_plot = ( + base.mark_line(strokeWidth=2.5) + .encode( + x=base_encode["x"], + color=base_encode["color"], + tooltip=tooltip_fields + ["listeners__mass__dry_mass_fold_change:Q"], + detail="lineage_seed:N", + y=alt.Y( + "listeners__mass__dry_mass_fold_change:Q", + title="Normalized Dry Mass", + scale=alt.Scale(nice=False), + ), + ) + .properties( + width=400, height=200, title=f"{variant_name} - Normalized Dry Mass" + ) + ) + + # Add reference line at y=2 (doubling mass) + reference_line = ( + alt.Chart(pd.DataFrame({"y": [2]})) + .mark_rule(color="red", strokeDash=[5, 5], strokeWidth=1) + .encode(y="y:Q") + ) + + norm_mass_plot = norm_mass_plot + reference_line + + # Combine plots for this variant + variant_combined = ( + alt.hconcat(mass_plot, norm_mass_plot) + .resolve_scale(x="shared") + .properties(title=f"{variant_name} Cell Mass Analysis") + ) + + plots.append(variant_combined) + + # Create combined plot + final_plot = plots[0] if len(plots) == 1 else alt.vconcat(*plots) + final_plot = final_plot.resolve_scale(x="independent", y="independent").properties( + title="Multi-Variant Cell Mass Analysis" + ) + + # Save plot + out_path = os.path.join(outdir, "multivariant_cell_mass_report.html") + final_plot.save(out_path) + print(f"Saved multi-variant cell mass visualization to: {out_path}") + + return final_plot diff --git a/ecoli/analysis/multivariant/fba_flux.py b/ecoli/analysis/multivariant/fba_flux.py new file mode 100644 index 000000000..16a570b6a --- /dev/null +++ b/ecoli/analysis/multivariant/fba_flux.py @@ -0,0 +1,457 @@ +""" +Visualize FBA reaction fluxes over time for specified reactions with net flux calculation across multiple variants. + +Supports two visualization modes: +1. 'grid' mode: Each row represents a variant, each column represents a reaction +2. 'stacked' mode: Each reaction gets its own chart, variants shown as different colored lines + +You can specify the reactions and layout using parameters: + "fba_flux": { + # Required: specify BioCyc reaction IDs to visualize + "BioCyc_ID": ["Name1", "Name2", ...], + # Optional: specify variants to visualize + # If not specified, all variants will be used + "variant": [1, 2, ...], + # Optional: specify generations to visualize + # If not specified, all generations will be used + "generation": [1, 2, ...], + # Optional: specify layout mode ('grid' or 'stacked') + # Default: 'stacked' + "layout": "stacked" # or "grid" + } + +This script uses the base reaction ID to extended reaction mapping to efficiently +find forward and reverse reactions, then calculates net flux using SQL for +optimal memory usage and performance. +""" + +import altair as alt +import os +from typing import Any + +import polars as pl +from duckdb import DuckDBPyConnection +import pandas as pd + +from ecoli.library.parquet_emitter import field_metadata +from ecoli.analysis.utils import ( + create_base_to_extended_mapping, + build_flux_calculation_sql, +) + + +def plot( + params: dict[str, Any], + conn: DuckDBPyConnection, + history_sql: str, + config_sql: str, + success_sql: str, + sim_data_dict: dict[str, dict[int, str]], + validation_data_paths: list[str], + outdir: str, + variant_metadata: dict[str, dict[int, Any]], + variant_names: dict[str, str], +): + """Visualize FBA reaction net fluxes with configurable layout modes.""" + + # Get parameters + biocyc_ids = params.get("BioCyc_ID", []) + if not biocyc_ids: + print( + "[ERROR] No BioCyc_ID found in params. Please specify reaction IDs to visualize." + ) + return None + + if isinstance(biocyc_ids, str): + biocyc_ids = [biocyc_ids] + + # Get layout mode (default to 'stacked') + layout_mode = params.get("layout", "stacked").lower() + if layout_mode not in ["grid", "stacked"]: + print(f"[WARNING] Invalid layout mode '{layout_mode}'. Using 'stacked' mode.") + layout_mode = "stacked" + + print( + f"[INFO] Visualizing net fluxes for {len(biocyc_ids)} reactions: {biocyc_ids}" + ) + print(f"[INFO] Using layout mode: {layout_mode}") + + # Create base to extended reaction mapping + base_to_extended_mapping = create_base_to_extended_mapping(sim_data_dict) + if not base_to_extended_mapping: + print("[ERROR] Could not create base to extended reaction mapping") + return None + + # Load reaction IDs from config + try: + all_reaction_ids = field_metadata( + conn, config_sql, "listeners__fba_results__reaction_fluxes" + ) + print(f"[INFO] Total reactions in sim_data: {len(all_reaction_ids)}") + except Exception as e: + print(f"[ERROR] Error loading reaction IDs: {e}") + return None + + # Build SQL query for efficient flux calculation + flux_calculation_sql, valid_biocyc_ids = build_flux_calculation_sql( + biocyc_ids, base_to_extended_mapping, all_reaction_ids, history_sql + ) + + if not flux_calculation_sql or not valid_biocyc_ids: + print("[ERROR] Could not build flux calculation SQL") + return None + + print(f"[INFO] Processing {len(valid_biocyc_ids)} valid BioCyc IDs") + + # Execute the optimized SQL query + try: + df = conn.sql(flux_calculation_sql).pl() + print(f"[INFO] Loaded data with {df.height} time steps") + except Exception as e: + print(f"[ERROR] Error executing flux calculation SQL: {e}") + return None + + if df.is_empty(): + print("[ERROR] No data found") + return None + + # Filter by specified variants and generations if provided + target_variants = params.get("variant", None) + target_generations = params.get("generation", None) + + if target_variants is not None: + print(f"[INFO] Filtering for variants: {target_variants}") + df = df.filter(pl.col("variant").is_in(target_variants)) + + if target_generations is not None: + print(f"[INFO] Filtering for generations: {target_generations}") + df = df.filter(pl.col("generation").is_in(target_generations)) + + # Print variant and generation information + unique_variants = sorted(df["variant"].unique().to_list()) + unique_generations = sorted(df["generation"].unique().to_list()) + print(f"[INFO] Found {len(unique_variants)} variants: {unique_variants}") + print(f"[INFO] Found {len(unique_generations)} generations: {unique_generations}") + + # Print average net flux information + for biocyc_id in valid_biocyc_ids: + net_flux_col = f"{biocyc_id}_net_flux" + if net_flux_col in df.columns: + avg_flux = df[net_flux_col].mean() + print( + f"[INFO] Average net flux for {biocyc_id}: {avg_flux:.6f} mmol/gDW/hr" + ) + + # Calculate average net fluxes based on layout mode + if layout_mode == "grid": + # For grid mode: calculate averages by variant, generation, and reaction + avg_data = [] + for biocyc_id in valid_biocyc_ids: + net_flux_col = f"{biocyc_id}_net_flux" + if net_flux_col in df.columns: + variant_gen_avgs = df.group_by(["variant", "generation"]).agg( + pl.col(net_flux_col).mean().alias("avg_net_flux") + ) + + for row in variant_gen_avgs.iter_rows(named=True): + avg_data.append( + { + "biocyc_id": biocyc_id, + "variant": row["variant"], + "generation": row["generation"], + "avg_net_flux": row["avg_net_flux"], + } + ) + + avg_df = pl.DataFrame(avg_data) + else: + # For stacked mode: calculate averages by variant and reaction + avg_net_fluxes = {} + for biocyc_id in valid_biocyc_ids: + net_flux_col = f"{biocyc_id}_net_flux" + if net_flux_col in df.columns: + variant_avgs = df.group_by("variant").agg( + pl.col(net_flux_col).mean().alias("avg_net_flux") + ) + avg_net_fluxes[biocyc_id] = variant_avgs + + # Create visualization based on layout mode + if layout_mode == "grid": + combined_plot = create_grid_visualization( + df, avg_df, valid_biocyc_ids, unique_variants, unique_generations + ) + output_filename = "fba_net_flux_grid_analysis.html" + title_suffix = ( + f"{len(unique_variants)} Variants × {len(valid_biocyc_ids)} Reactions" + ) + else: + combined_plot = create_stacked_visualization( + df, avg_net_fluxes, valid_biocyc_ids + ) + output_filename = "fba_net_flux_stacked_analysis.html" + title_suffix = f"Multi-Variant Analysis ({len(valid_biocyc_ids)} reactions)" + + if combined_plot is None: + print("[ERROR] Failed to create visualization") + return None + + # Add overall title + combined_plot = combined_plot.properties( + title=alt.TitleParams( + text=f"FBA Net Flux Analysis: {title_suffix}", + fontSize=16, + anchor="start", + ) + ).resolve_scale(color="shared") + + # Save the plot + output_path = os.path.join(outdir, output_filename) + combined_plot.save(output_path) + print(f"[INFO] Saved visualization to: {output_path}") + + return combined_plot + + +def create_grid_visualization( + df, avg_df, valid_biocyc_ids, unique_variants, unique_generations +): + """Create grid layout visualization (rows = variants, columns = reactions).""" + + def create_subplot_chart(variant, biocyc_id): + """Create a single subplot for a specific variant-reaction combination.""" + net_flux_col = f"{biocyc_id}_net_flux" + + # Check if the column exists in dataframe + if net_flux_col not in df.columns: + print(f"[WARNING] Column {net_flux_col} not found in dataframe") + return None + + # Filter data for this variant and reaction + subplot_data = ( + df.filter(pl.col("variant") == variant) + .select(["time_min", "generation", net_flux_col]) + .filter(pl.col(net_flux_col).is_not_null()) + ) + + if subplot_data.height == 0: + print(f"[WARNING] No data for variant {variant}, reaction {biocyc_id}") + return None + + # Main line chart with generations as different colors + line_chart = ( + alt.Chart(subplot_data) + .mark_line(strokeWidth=1.5) + .encode( + x=alt.X( + "time_min:Q", + title="Time (min)" if variant == unique_variants[-1] else "", + ), + y=alt.Y( + f"{net_flux_col}:Q", + title="Net Flux (mmol/gDW/hr)" + if biocyc_id == valid_biocyc_ids[0] + else "", + ), + color=alt.Color( + "generation:N", + legend=alt.Legend(title="Generation") + if variant == unique_variants[0] + and biocyc_id == valid_biocyc_ids[0] + else None, + ), + tooltip=["time_min:Q", f"{net_flux_col}:Q", "generation:N"], + ) + ) + + # Average lines for each generation in this variant + avg_subset = avg_df.filter( + (pl.col("variant") == variant) & (pl.col("biocyc_id") == biocyc_id) + ) + + if avg_subset.height > 0: + avg_lines = ( + alt.Chart(avg_subset) + .mark_rule(strokeDash=[3, 3], strokeWidth=1.5) + .encode( + y=alt.Y("avg_net_flux:Q"), + color=alt.Color("generation:N", legend=None), + tooltip=[ + alt.Tooltip("avg_net_flux:Q", format=".4f"), + "generation:N", + ], + ) + ) + else: + avg_lines = alt.Chart().mark_point() # Empty chart + + # Zero line for reference + zero_line = ( + alt.Chart(pd.DataFrame({"zero": [0]})) + .mark_rule(color="gray", strokeDash=[1, 1], strokeWidth=1, opacity=0.5) + .encode(y=alt.Y("zero:Q")) + ) + + # Combine all elements + combined = (line_chart + avg_lines + zero_line).resolve_scale(color="shared") + + # Add title only for top row + if variant == unique_variants[0]: + combined = combined.properties(title=f"{biocyc_id}") + + combined = combined.properties(width=400, height=300) + + return combined + + def create_empty_subplot(): + """Create an empty placeholder subplot.""" + return ( + alt.Chart(pd.DataFrame({"x": [0], "y": [0]})) + .mark_point(opacity=0) + .properties(width=200, height=150) + ) + + # Create subplot grid: rows = variants, columns = reactions + subplot_grid = [] + + for variant in unique_variants: + variant_row = [] + for biocyc_id in valid_biocyc_ids: + subplot = create_subplot_chart(variant, biocyc_id) + if subplot is not None: + variant_row.append(subplot) + else: + # Create empty placeholder if no data + variant_row.append(create_empty_subplot()) + + if variant_row: + # Add variant label on the left + variant_label = ( + alt.Chart(pd.DataFrame({"label": [f"Variant {variant}"]})) + .mark_text( + align="center", baseline="middle", fontSize=12, fontWeight="bold" + ) + .encode(text="label:N") + .properties(width=160, height=300) + ) + + # Combine variant label with row of subplots + row_with_label = alt.hconcat(variant_label, *variant_row, spacing=10) + subplot_grid.append(row_with_label) + + if not subplot_grid: + print("[ERROR] No valid subplots could be created") + return None + + # Combine all rows + combined_plot = alt.vconcat(*subplot_grid, spacing=20) + return combined_plot + + +def create_stacked_visualization(df, avg_net_fluxes, valid_biocyc_ids): + """Create stacked layout visualization (one chart per reaction, variants as colored lines).""" + + def create_net_flux_chart(biocyc_id, net_flux_col, variant_avgs): + """Create a line chart for a single reaction net flux with average lines for each variant.""" + + # Check if the column exists in dataframe + if net_flux_col not in df.columns: + print(f"[WARNING] Column {net_flux_col} not found in dataframe") + return None + + # Select only the columns we need to minimize data transfer + data = df.select(["time_min", "generation", "variant", net_flux_col]) + + # Remove any null values + data = data.filter(pl.col(net_flux_col).is_not_null()) + + if data.height == 0: + print(f"[WARNING] No valid data for reaction {biocyc_id}") + return None + + # Main net flux line chart (different variants as different colored lines) + net_flux_chart = ( + alt.Chart(data) + .mark_line(strokeWidth=2) + .encode( + x=alt.X("time_min:Q", title="Time (min)"), + y=alt.Y(f"{net_flux_col}:Q", title="Net Flux (mmol/gDW/hr)"), + color=alt.Color("variant:N", legend=alt.Legend(title="Variant")), + tooltip=[ + "time_min:Q", + f"{net_flux_col}:Q", + "variant:N", + "generation:N", + ], + ) + ) + + # Create average lines for each variant + avg_line_data = [] + for row in variant_avgs.iter_rows(named=True): + variant_name = row["variant"] + avg_value = row["avg_net_flux"] + avg_line_data.append( + { + "variant": variant_name, + "avg_net_flux": avg_value, + "label": f"{variant_name} Avg: {avg_value:.4f}", + } + ) + + if avg_line_data: + avg_line_df = pd.DataFrame(avg_line_data) + + avg_lines = ( + alt.Chart(avg_line_df) + .mark_rule(strokeDash=[5, 5], strokeWidth=2) + .encode( + y=alt.Y("avg_net_flux:Q"), + color=alt.Color( + "variant:N", legend=None + ), # Use same color scale as main chart + tooltip=["label:N"], + ) + ) + else: + avg_lines = alt.Chart().mark_point() # Empty chart + + # Zero line for reference + zero_line = ( + alt.Chart(pd.DataFrame({"zero": [0], "label": ["Zero"]})) + .mark_rule(color="gray", strokeDash=[2, 2], strokeWidth=1, opacity=0.7) + .encode(y=alt.Y("zero:Q"), tooltip=["label:N"]) + ) + + # Combine net flux line, average lines, and zero line + combined_chart = ( + (net_flux_chart + avg_lines + zero_line) + .properties(title=f"Net Flux vs Time: {biocyc_id}", width=600, height=300) + .resolve_scale(y="shared", color="shared") + ) + + return combined_chart + + # Create charts for each reaction + charts = [] + + for biocyc_id in valid_biocyc_ids: + net_flux_col = f"{biocyc_id}_net_flux" + variant_avgs = avg_net_fluxes.get(biocyc_id) + if variant_avgs is not None: + chart = create_net_flux_chart(biocyc_id, net_flux_col, variant_avgs) + if chart is not None: + charts.append(chart) + + if not charts: + print("[ERROR] No valid charts could be created") + return None + + # Arrange charts vertically + if len(charts) == 1: + combined_plot = charts[0] + else: + combined_plot = alt.vconcat(*charts).resolve_scale( + x="shared", y="independent", color="shared" + ) + + return combined_plot diff --git a/ecoli/analysis/multivariant/fba_flux_process.py b/ecoli/analysis/multivariant/fba_flux_process.py new file mode 100644 index 000000000..7c2775df2 --- /dev/null +++ b/ecoli/analysis/multivariant/fba_flux_process.py @@ -0,0 +1,300 @@ +""" +Visualize FBA reaction fluxes over time for a biological process by aggregating multiple reactions +across multiple variants. + +Layout: The specified biological process will be visualized as a single line chart, +with different variants shown as different colored lines. The total flux is calculated +as the sum of net fluxes from all reactions specified in the BioCyc_ID list. + +You can specify the biological process to visualize using parameters in params: + "fba_process_flux": { + # Required: specify BioCyc IDs of reactions involved in the biological process + "BioCyc_ID": ["Name1", "Name2", ...], + # Required: name of the biological process for visualization + "process_name": "Glucose Transport", + # Optional: specify variants to visualize + # If not specified, all variants will be used + "variant": ["variant1", "variant2", ...], + # Optional: specify generation for visualization + # If not specified, all generations will be used + "generation": [1, 2, 3, ...] + } + +This script uses the base reaction ID to extended reaction mapping to efficiently +find forward and reverse reactions, then calculates total process flux using SQL for +optimal memory usage and performance. +""" + +import altair as alt +import os +from typing import Any + +import polars as pl +from duckdb import DuckDBPyConnection +import pandas as pd + +from ecoli.library.parquet_emitter import field_metadata +from ecoli.analysis.utils import ( + create_base_to_extended_mapping, + build_flux_calculation_sql, +) + + +def build_process_flux_sql( + biocyc_ids, base_to_extended_mapping, all_reaction_ids, history_sql, process_name +): + """ + Build SQL query to calculate total biological process flux by summing net fluxes from all reactions. + + Args: + biocyc_ids (list): List of BioCyc IDs (base reaction IDs) for the biological process + base_to_extended_mapping (dict): Mapping from base to extended reactions + all_reaction_ids (list): List of all reaction IDs from field_metadata + history_sql (str): SQL query for historical data + process_name (str): Name of the biological process + + Returns: + tuple: (sql_query, valid_biocyc_ids, process_flux_column_name) + """ + # First, use the existing function to get individual reaction flux calculations + individual_flux_sql, valid_biocyc_ids = build_flux_calculation_sql( + biocyc_ids, base_to_extended_mapping, all_reaction_ids, history_sql + ) + + if not individual_flux_sql or not valid_biocyc_ids: + print("[ERROR] Could not build individual flux calculations for process") + return None, [], "" + + # Extract just the flux calculations from the individual SQL + # Parse the SELECT clause to get individual flux expressions + flux_calculations = [] + for biocyc_id in valid_biocyc_ids: + flux_calculations.append(f'"{biocyc_id}_net_flux"') + + # Create safe process column name + safe_process_col = f'"{process_name.replace(" ", "_")}_total_flux"' + + # Build the process flux calculation SQL + # This sums all individual net fluxes to get total process flux + process_flux_expr = f"({' + '.join(flux_calculations)}) AS {safe_process_col}" + + # Build complete SQL query that first calculates individual fluxes, then sums them + sql = f""" + WITH individual_fluxes AS ( + {individual_flux_sql} + ) + SELECT + time, + generation, + variant, + time_min, + {process_flux_expr} + FROM individual_fluxes + ORDER BY variant, generation, time + """ + + return sql, valid_biocyc_ids, safe_process_col.strip('"') + + +def plot( + params: dict[str, Any], + conn: DuckDBPyConnection, + history_sql: str, + config_sql: str, + success_sql: str, + sim_data_dict: dict[str, dict[int, str]], + validation_data_paths: list[str], + outdir: str, + variant_metadata: dict[str, dict[int, Any]], + variant_names: dict[str, str], +): + """Visualize FBA biological process flux over time by aggregating multiple reactions across variants.""" + + # Get parameters + biocyc_ids = params.get("BioCyc_ID", []) + process_name = params.get("process_name", "Biological Process") + + if not biocyc_ids: + print( + "[ERROR] No BioCyc_ID found in params. Please specify reaction IDs for the biological process." + ) + return None + + if isinstance(biocyc_ids, str): + biocyc_ids = [biocyc_ids] + + print( + f"[INFO] Analyzing biological process '{process_name}' with {len(biocyc_ids)} reactions: {biocyc_ids}" + ) + + # Create base to extended reaction mapping + base_to_extended_mapping = create_base_to_extended_mapping(sim_data_dict) + if not base_to_extended_mapping: + print("[ERROR] Could not create base to extended reaction mapping") + return None + + # Load reaction IDs from config + try: + all_reaction_ids = field_metadata( + conn, config_sql, "listeners__fba_results__reaction_fluxes" + ) + print(f"[INFO] Total reactions in sim_data: {len(all_reaction_ids)}") + except Exception as e: + print(f"[ERROR] Error loading reaction IDs: {e}") + return None + + # Build SQL query for efficient process flux calculation + process_flux_sql, valid_biocyc_ids, process_flux_col = build_process_flux_sql( + biocyc_ids, + base_to_extended_mapping, + all_reaction_ids, + history_sql, + process_name, + ) + + if not process_flux_sql or not valid_biocyc_ids: + print("[ERROR] Could not build process flux calculation SQL") + return None + + print( + f"[INFO] Processing {len(valid_biocyc_ids)} valid BioCyc IDs for {process_name}" + ) + + # Execute the optimized SQL query + try: + df = conn.sql(process_flux_sql).pl() + print(f"[INFO] Loaded data with {df.height} time steps") + except Exception as e: + print(f"[ERROR] Error executing process flux calculation SQL: {e}") + return None + + if df.is_empty(): + print("[ERROR] No data found") + return None + + # Filter by specified variants and generations if provided + target_variants = params.get("variant", None) + target_generation = params.get("generation", None) + + if target_variants is not None: + print(f"[INFO] Filtering for variants: {target_variants}") + df = df.filter(pl.col("variant").is_in(target_variants)) + + if target_generation is not None: + print(f"[INFO] Filtering for generations: {target_generation}") + df = df.filter(pl.col("generation").is_in(target_generation)) + + # Print variant information + unique_variants = df["variant"].unique().to_list() + print(f"[INFO] Found {len(unique_variants)} variants: {unique_variants}") + + # Calculate average process flux for each variant + variant_averages = df.group_by("variant").agg( + pl.col(process_flux_col).mean().alias("avg_process_flux") + ) + + print(f"[INFO] Average {process_name} flux by variant:") + for row in variant_averages.iter_rows(named=True): + print(f" {row['variant']}: {row['avg_process_flux']:.6f} mmol/gDW/hr") + + # --------------------------- # + # Create visualization + def create_process_flux_chart(): + """Create a line chart for the biological process total flux.""" + + # Select only needed columns + chart_data = df.select(["time_min", "generation", "variant", process_flux_col]) + + # Remove any null values + chart_data = chart_data.filter(pl.col(process_flux_col).is_not_null()) + + if chart_data.height == 0: + print(f"[WARNING] No valid data for process {process_name}") + return None + + # Main process flux line chart + main_chart = ( + alt.Chart(chart_data) + .mark_line(strokeWidth=2) + .encode( + x=alt.X("time_min:Q", title="Time (min)"), + y=alt.Y(f"{process_flux_col}:Q", title="Process Flux (mmol/gDW/hr)"), + color=alt.Color("variant:N", legend=alt.Legend(title="Variant")), + tooltip=[ + "time_min:Q", + f"{process_flux_col}:Q", + "variant:N", + "generation:N", + ], + ) + ) + + # Create average lines for each variant + avg_line_data = [] + for row in variant_averages.iter_rows(named=True): + variant_name = row["variant"] + avg_value = row["avg_process_flux"] + avg_line_data.append( + { + "variant": variant_name, + "avg_flux": avg_value, + "label": f"{variant_name} Avg: {avg_value:.4f}", + } + ) + + avg_line_df = pd.DataFrame(avg_line_data) + + avg_lines = ( + alt.Chart(avg_line_df) + .mark_rule(strokeDash=[5, 5], strokeWidth=2) + .encode( + y=alt.Y("avg_flux:Q"), + color=alt.Color("variant:N", legend=None), + tooltip=["label:N"], + ) + ) + + # Zero line for reference + zero_line_data = pd.DataFrame({"zero": [0], "label": ["Zero"]}) + zero_line = ( + alt.Chart(zero_line_data) + .mark_rule(color="gray", strokeDash=[2, 2], strokeWidth=1, opacity=0.7) + .encode(y=alt.Y("zero:Q"), tooltip=["label:N"]) + ) + + # Combine all elements + combined_chart = ( + (main_chart + avg_lines + zero_line) + .properties( + title=f"Total Flux vs Time: {process_name}", width=800, height=400 + ) + .resolve_scale(y="shared", color="shared") + ) + + return combined_chart + + # --------------------------- # + + # Create the chart + chart = create_process_flux_chart() + if chart is None: + print("[ERROR] Could not create chart") + return None + + # Add overall title + final_chart = chart.properties( + title=alt.TitleParams( + text=f"Biological Process Flux Analysis: {process_name} ({len(valid_biocyc_ids)} reactions)", + fontSize=16, + anchor="start", + ) + ) + + # Save the plot + output_path = os.path.join( + outdir, f"fba_process_flux_{process_name.replace(' ', '_').lower()}.html" + ) + final_chart.save(output_path) + print(f"[INFO] Saved visualization to: {output_path}") + + return final_chart diff --git a/ecoli/analysis/single/escher_vis.py b/ecoli/analysis/single/escher_vis.py new file mode 100644 index 000000000..a0b80919b --- /dev/null +++ b/ecoli/analysis/single/escher_vis.py @@ -0,0 +1,361 @@ +""" +Utilize Escher API for visualizing fluxes in E. coli models. +This single-analysis method provides functionality to calculate generational average fluxes, +and then visualize them using Escher with BiGG ID mapping. + +You can specify the Escher map and the mapping CSV file in params: + "escher_vis": { + # Required: path to the CSV file with BiGG and BioCyc IDs + "csv_file_path": "path/to/bigg_biocyc_mapping.csv", + # Optional: specify Escher map name to visualize + # If not specified, defaults to 'e_coli_core.Core metabolism' + "map_name": "e_coli_core.Core metabolism", + } +""" + +import pandas as pd +import numpy as np +from duckdb import DuckDBPyConnection +from escher import Builder +import os +from typing import Dict, Any + +from ecoli.library.parquet_emitter import field_metadata +from ecoli.analysis.utils import ( + create_base_to_extended_mapping, + build_flux_calculation_sql, +) + + +class EscherFluxVisualizer: + """ + A class to calculate generational average fluxes from BioCyc IDs + and visualize them using Escher with BiGG ID mapping. + """ + + def __init__(self, csv_file_path: str): + """ + Initialize the visualizer with the CSV mapping file. + + Args: + csv_file_path (str): Path to CSV file with columns: + Original_Reaction_ID, BiGG_ID, BioCyc_ID + """ + self.csv_file_path = csv_file_path + self.mapping_df = None + self.average_fluxes = {} + + # Load the CSV mapping file + self._load_mapping_file() + + def _load_mapping_file(self): + """Load and validate the CSV mapping file.""" + try: + self.mapping_df = pd.read_csv(self.csv_file_path) + print(f"[INFO] Loaded mapping file with {len(self.mapping_df)} reactions") + + # Validate required columns + required_cols = ["Original_Reaction_ID", "BiGG_ID", "BioCyc_ID"] + missing_cols = [ + col for col in required_cols if col not in self.mapping_df.columns + ] + if missing_cols: + raise ValueError(f"Missing required columns: {missing_cols}") + + # Filter out rows with empty BioCyc_ID + initial_count = len(self.mapping_df) + self.mapping_df = self.mapping_df.dropna(subset=["BioCyc_ID"]) + self.mapping_df = self.mapping_df[ + self.mapping_df["BioCyc_ID"].str.strip() != "" + ] + filtered_count = len(self.mapping_df) + + print( + f"[INFO] Filtered to {filtered_count} reactions with valid BioCyc_ID " + f"(removed {initial_count - filtered_count} empty entries)" + ) + + except Exception as e: + print(f"[ERROR] Failed to load mapping file: {e}") + raise + + def calculate_generational_average_flux( + self, + conn: DuckDBPyConnection, + history_sql: str, + config_sql: str, + sim_data_dict: dict[str, dict[int, str]], + ) -> Dict[str, float]: + """ + Calculate generational average flux for all BioCyc IDs using optimized SQL approach. + + Args: + conn: DuckDB connection + history_sql: SQL for historical data + config_sql: SQL for configuration data + sim_data_dict: Simulation data dictionary for creating base-to-extended mapping + + Returns: + Dict[str, float]: Mapping of BioCyc_ID to average flux + """ + print("[INFO] Calculating generational average flux...") + + # Get unique BioCyc IDs from the mapping file + biocyc_ids = self.mapping_df["BioCyc_ID"].unique().tolist() + print(f"[INFO] Found {len(biocyc_ids)} unique BioCyc IDs to process") + + # Create base to extended reaction mapping + base_to_extended_mapping = create_base_to_extended_mapping(sim_data_dict) + if not base_to_extended_mapping: + print("[ERROR] Could not create base to extended reaction mapping") + return {} + + # Load reaction IDs from config + try: + all_reaction_ids = field_metadata( + conn, config_sql, "listeners__fba_results__reaction_fluxes" + ) + print(f"[INFO] Total reactions in sim_data: {len(all_reaction_ids)}") + except Exception as e: + print(f"[ERROR] Error loading reaction IDs: {e}") + return {} + + # Build SQL query for efficient flux calculation + flux_calculation_sql, valid_biocyc_ids = build_flux_calculation_sql( + biocyc_ids, base_to_extended_mapping, all_reaction_ids, history_sql + ) + + if not flux_calculation_sql or not valid_biocyc_ids: + print("[ERROR] Could not build flux calculation SQL") + return {} + + print(f"[INFO] Processing {len(valid_biocyc_ids)} valid BioCyc IDs") + + # Execute the optimized SQL query + try: + df = conn.sql(flux_calculation_sql).pl() + print(f"[INFO] Loaded flux data with {df.height} time steps") + except Exception as e: + print(f"[ERROR] Error executing flux calculation SQL: {e}") + return {} + + if df.height == 0: + print("[ERROR] No data found") + return {} + + # Calculate average flux for each BioCyc ID + average_fluxes = {} + + for biocyc_id in valid_biocyc_ids: + net_flux_col = f"{biocyc_id}_net_flux" + if net_flux_col in df.columns: + # Calculate average net flux using Polars + avg_net_flux = df[net_flux_col].mean() + average_fluxes[biocyc_id] = avg_net_flux + + print( + f"[INFO] Average net flux for {biocyc_id}: {avg_net_flux:.6f} mmol/gDW/hr" + ) + else: + print(f"[WARNING] Column {net_flux_col} not found in results") + + self.average_fluxes = average_fluxes + return average_fluxes + + def create_escher_flux_map( + self, + output_path: str = "flux_visualization.html", + map_name: str = "e_coli_core.Core metabolism", + ) -> Builder: + """ + Create Escher flux visualization using BiGG IDs and calculated fluxes. + + Args: + output_path (str): Path to save the HTML visualization + map_name (str): Escher map name to use + + Returns: + Builder: Escher Builder object + """ + if not self.average_fluxes: + print( + "[ERROR] No average fluxes calculated. Run calculate_generational_average_flux first." + ) + return None + + print( + f"[INFO] Creating Escher visualization with {len(self.average_fluxes)} flux values..." + ) + + # Create flux dictionary using BiGG IDs with proper data cleaning + bigg_flux_dict = {} + mapped_count = 0 + + for _, row in self.mapping_df.iterrows(): + biocyc_id = row["BioCyc_ID"] + bigg_id = str(row["BiGG_ID"]).strip() # Ensure string and strip whitespace + + if biocyc_id in self.average_fluxes: + flux_value = float(self.average_fluxes[biocyc_id]) # Ensure float type + + # Skip invalid values + if np.isnan(flux_value) or np.isinf(flux_value): + print( + f"[WARNING] Skipping invalid flux value for {biocyc_id}: {flux_value}" + ) + continue + + bigg_flux_dict[bigg_id] = flux_value + mapped_count += 1 + print(f"[INFO] Mapped {biocyc_id} -> {bigg_id}: {flux_value:.6f}") + + print( + f"[INFO] Successfully mapped {mapped_count} reactions for Escher visualization" + ) + + if not bigg_flux_dict: + print( + "[ERROR] No flux values mapped to BiGG IDs. Cannot create visualization." + ) + return None + + # Create Escher builder - using minimal initialization like original working version + try: + builder = Builder( + map_name=map_name, + reaction_data=bigg_flux_dict, + reaction_scale=[ + {"type": "min", "color": "#c8e6c9", "size": 12}, + {"type": "median", "color": "#81c784", "size": 20}, + {"type": "max", "color": "#388e3c", "size": 25}, + ], + reaction_no_data_color="#ddd", + reaction_no_data_size=8, + ) + + # Save to HTML file + builder.save_html(output_path) + print(f"[INFO] Escher visualization saved to: {output_path}") + + # Save as JSON for debugging + json_path = output_path.replace(".html", "_data.json") + import json + + with open(json_path, "w") as f: + json.dump(bigg_flux_dict, f, indent=2) + print(f"[DEBUG] Flux data also saved as JSON: {json_path}") + + return builder + + except Exception as e: + print(f"[ERROR] Failed to create Escher visualization: {e}") + return None + + def generate_flux_summary(self) -> pd.DataFrame: + """ + Generate a summary DataFrame with BioCyc_ID, BiGG_ID, and calculated fluxes. + + Returns: + pd.DataFrame: Summary of flux calculations + """ + if not self.average_fluxes: + print("[ERROR] No average fluxes calculated.") + return pd.DataFrame() + + summary_data = [] + + for _, row in self.mapping_df.iterrows(): + biocyc_id = row["BioCyc_ID"] + bigg_id = row["BiGG_ID"] + original_id = row["Original_Reaction_ID"] + + flux_value = self.average_fluxes.get(biocyc_id, np.nan) + + summary_data.append( + { + "Original_Reaction_ID": original_id, + "BiGG_ID": bigg_id, + "BioCyc_ID": biocyc_id, + "Average_Flux": flux_value, + "Has_Flux_Data": not np.isnan(flux_value), + } + ) + + summary_df = pd.DataFrame(summary_data) + + # Print summary statistics + total_reactions = len(summary_df) + reactions_with_data = summary_df["Has_Flux_Data"].sum() + print( + f"[INFO] Summary: {reactions_with_data}/{total_reactions} reactions have flux data" + ) + + return summary_df + + def save_flux_summary(self, output_path: str = "flux_summary.csv"): + """Save flux summary to CSV file.""" + summary_df = self.generate_flux_summary() + if not summary_df.empty: + summary_df.to_csv(output_path, index=False) + print(f"[INFO] Flux summary saved to: {output_path}") + else: + print("[WARNING] No summary data to save") + + +def plot( + params: dict[str, Any], + conn: DuckDBPyConnection, + history_sql: str, + config_sql: str, + success_sql: str, + sim_data_dict: dict[str, dict[int, str]], + validation_data_paths: list[str], + outdir: str, + variant_metadata: dict[str, dict[int, Any]], + variant_names: dict[str, str], +): + """Complete pipeline to run flux visualization with Escher.""" + + # Get parameters + csv_file_path = params.get("csv_file_path") + if not csv_file_path: + print("[ERROR] csv_file_path parameter is required") + return None + + escher_map_name = params.get("map_name", "e_coli_core.Core metabolism") + print(f"[INFO] Using Escher map: {escher_map_name}") + + # Create output directory if it doesn't exist + os.makedirs(outdir, exist_ok=True) + + try: + # Initialize visualizer + visualizer = EscherFluxVisualizer(csv_file_path) + + # Calculate generational average flux using the efficient SQL approach + average_fluxes = visualizer.calculate_generational_average_flux( + conn, history_sql, config_sql, sim_data_dict + ) + + if not average_fluxes: + print("[ERROR] No fluxes calculated. Cannot proceed with visualization.") + return None + + # Generate and save flux summary + summary_path = os.path.join(outdir, "escher_flux_summary.csv") + visualizer.save_flux_summary(summary_path) + + # Create Escher visualization + escher_path = os.path.join(outdir, "escher_flux_visualization.html") + builder = visualizer.create_escher_flux_map(escher_path, escher_map_name) + + if builder is None: + print("[ERROR] Failed to create Escher visualization") + return None + + print("[INFO] Escher flux visualization completed successfully!") + return visualizer, builder + + except Exception as e: + print(f"[ERROR] Failed to complete Escher visualization pipeline: {e}") + return None diff --git a/ecoli/analysis/single/fba_flux.py b/ecoli/analysis/single/fba_flux.py new file mode 100644 index 000000000..895153b01 --- /dev/null +++ b/ecoli/analysis/single/fba_flux.py @@ -0,0 +1,261 @@ +""" +Visualize FBA reaction net fluxes over time for a single generation with specified time window. + +You can specify the reactions to visualize using the 'BioCyc_ID' parameter in params: + "single_gen_fba_flux": { + "BioCyc_ID": ["Name1", "Name2", ...], + # Optional: specify time window to analyze + # If not specified, all time points will be used + "time_window": [start_time, end_time] # in seconds + } + +This script uses the base reaction ID to extended reaction mapping to efficiently +find forward and reverse reactions, then calculates net flux using SQL for +optimal memory usage and performance. + +For each reaction in params.BioCyc_ID, plot net flux vs time with each reaction having its own subplot. +""" + +import altair as alt +import os +from typing import Any + +import polars as pl +from duckdb import DuckDBPyConnection +import pandas as pd + +from ecoli.library.parquet_emitter import field_metadata +from ecoli.analysis.utils import ( + create_base_to_extended_mapping, + build_flux_calculation_sql, +) + + +def plot( + params: dict[str, Any], + conn: DuckDBPyConnection, + history_sql: str, + config_sql: str, + success_sql: str, + sim_data_dict: dict[str, dict[int, str]], + validation_data_paths: list[str], + outdir: str, + variant_metadata: dict[str, dict[int, Any]], + variant_names: dict[str, str], +): + """Visualize FBA reaction net fluxes over time for a single generation.""" + + # Get parameters + biocyc_ids = params.get("BioCyc_ID", []) + if not biocyc_ids: + print( + "[ERROR] No BioCyc_ID found in params. Please specify reaction IDs to visualize." + ) + return None + + if isinstance(biocyc_ids, str): + biocyc_ids = [biocyc_ids] + + print(f"[INFO] Analyzing for {len(biocyc_ids)} reactions: {biocyc_ids}") + + # Get time window (optional) + time_window = params.get("time_window", None) + if time_window is not None: + if len(time_window) != 2: + print( + "[ERROR] time_window must be a list of [start_time, end_time] in seconds." + ) + return None + start_time, end_time = time_window + print(f"[INFO] Time window: {start_time}s to {end_time}s") + else: + print("[INFO] Using full time range") + + # Create base to extended reaction mapping + base_to_extended_mapping = create_base_to_extended_mapping(sim_data_dict) + if not base_to_extended_mapping: + print("[ERROR] Could not create base to extended reaction mapping") + return None + + # Load reaction IDs from config + try: + all_reaction_ids = field_metadata( + conn, config_sql, "listeners__fba_results__reaction_fluxes" + ) + print(f"[INFO] Total reactions in sim_data: {len(all_reaction_ids)}") + except Exception as e: + print(f"[ERROR] Error loading reaction IDs: {e}") + return None + + # Build SQL query for efficient flux calculation + flux_calculation_sql, valid_biocyc_ids = build_flux_calculation_sql( + biocyc_ids, base_to_extended_mapping, all_reaction_ids, history_sql + ) + + if not flux_calculation_sql or not valid_biocyc_ids: + print("[ERROR] Could not build flux calculation SQL") + return None + + print(f"[INFO] Processing {len(valid_biocyc_ids)} valid BioCyc IDs") + + # Execute the optimized SQL query + try: + df = conn.sql(flux_calculation_sql).pl() + print(f"[INFO] Loaded data with {df.height} time steps") + except Exception as e: + print(f"[ERROR] Error executing flux calculation SQL: {e}") + return None + + if df.is_empty(): + print("[ERROR] No data found") + return None + + # Apply time window filter if specified + if time_window is not None: + start_time_min = start_time / 60 + end_time_min = end_time / 60 + df = df.filter( + (pl.col("time_min") >= start_time_min) + & (pl.col("time_min") <= end_time_min) + ) + print( + f"[INFO] Filtered to time window: {start_time_min:.2f} - {end_time_min:.2f} minutes" + ) + + if df.height == 0: + print("[ERROR] No data points in specified time window") + return None + + print(f"[INFO] Final dataset has {df.height} time steps") + + # Calculate statistics for each reaction + flux_stats = {} + for biocyc_id in valid_biocyc_ids: + net_flux_col = f"{biocyc_id}_net_flux" + if net_flux_col in df.columns: + stats = { + "avg": df[net_flux_col].mean(), + "min": df[net_flux_col].min(), + "max": df[net_flux_col].max(), + "std": df[net_flux_col].std(), + } + flux_stats[biocyc_id] = stats + + print( + f"[INFO] Net flux stats for {biocyc_id}: " + f"avg={stats['avg']:.6f}, min={stats['min']:.6f}, " + f"max={stats['max']:.6f} (mmol/gDW/hr)" + ) + + # --------------------------- # + + # Create visualization functions + def create_individual_net_flux_chart(biocyc_id, net_flux_col, stats): + """Create an individual chart for a single reaction net flux.""" + + # Check if the column exists in dataframe + if net_flux_col not in df.columns: + print(f"[WARNING] Column {net_flux_col} not found in dataframe") + return None + + # Select only the columns we need + data_pl = df.select(["time_min", net_flux_col]) + + # Remove any null values using Polars syntax + data_pl = data_pl.filter(pl.col(net_flux_col).is_not_null()) + + if data_pl.height == 0: + print(f"[WARNING] No valid data for reaction {biocyc_id}") + return None + + # Convert to pandas for Altair + data = data_pl.to_pandas() + + # Main net flux line chart + net_flux_chart = ( + alt.Chart(data) + .mark_line(strokeWidth=2, color="steelblue") + .encode( + x=alt.X("time_min:Q", title="Time (min)"), + y=alt.Y(f"{net_flux_col}:Q", title="Net Flux (mmol/gDW/hr)"), + tooltip=["time_min:Q", f"{net_flux_col}:Q"], + ) + ) + + # Average net flux horizontal line + avg_line = ( + alt.Chart( + pd.DataFrame( + { + "avg_net_flux": [stats["avg"]], + "label": [f"Avg: {stats['avg']:.4f}"], + } + ) + ) + .mark_rule(color="red", strokeDash=[5, 5], strokeWidth=2) + .encode(y=alt.Y("avg_net_flux:Q"), tooltip=["label:N"]) + ) + + # Zero line for reference + zero_line = ( + alt.Chart(pd.DataFrame({"zero": [0], "label": ["Zero"]})) + .mark_rule(color="gray", strokeDash=[2, 2], strokeWidth=1, opacity=0.7) + .encode(y=alt.Y("zero:Q"), tooltip=["label:N"]) + ) + + # Combine all elements + combined_chart = ( + (net_flux_chart + avg_line + zero_line) + .properties( + title=f"{biocyc_id} (Avg={stats['avg']:.4f}, Range=[{stats['min']:.4f}, {stats['max']:.4f}])", + width=600, + height=200, + ) + .resolve_scale(y="shared") + ) + + return combined_chart + + # --------------------------- # + + # Create individual charts for each reaction + charts = [] + + for biocyc_id in valid_biocyc_ids: + net_flux_col = f"{biocyc_id}_net_flux" + stats = flux_stats.get(biocyc_id, {}) + if stats: # Only create chart if we have valid stats + chart = create_individual_net_flux_chart(biocyc_id, net_flux_col, stats) + if chart is not None: + charts.append(chart) + + if not charts: + print("[ERROR] No valid charts could be created") + return None + + # Arrange charts vertically with shared x-axis + if len(charts) == 1: + combined_plot = charts[0] + else: + combined_plot = alt.vconcat(*charts).resolve_scale(x="shared", y="independent") + + # Add overall title + time_window_str = ( + f" (Time: {time_window[0] / 60:.1f}-{time_window[1] / 60:.1f} min)" + if time_window + else "" + ) + combined_plot = combined_plot.properties( + title=alt.TitleParams( + text=f"Single Generation FBA Net Fluxes{time_window_str}", + fontSize=16, + anchor="start", + ) + ) + + # Save the plot + output_path = os.path.join(outdir, "single_generation_fba_net_flux.html") + combined_plot.save(output_path) + print(f"[INFO] Saved visualization to: {output_path}") + + return combined_plot diff --git a/ecoli/analysis/single/fba_flux_heat_scatter.py b/ecoli/analysis/single/fba_flux_heat_scatter.py new file mode 100644 index 000000000..21be3f45a --- /dev/null +++ b/ecoli/analysis/single/fba_flux_heat_scatter.py @@ -0,0 +1,1004 @@ +""" +This script preprocesses FBA flux data by mapping extended reactions to base reactions, +computes net fluxes for base reactions (forward extended - reverse extended), +categorizes base reactions based on flux behavior (always positive, always negative, or oscillating), +and creates separate visualizations for each category. + +Modified to work with DuckDB connection and SQL queries instead of direct file loading. +All outputs are saved to outdir. +""" + +import pandas as pd +import numpy as np +import plotly.graph_objects as go +from plotly.subplots import make_subplots +from scipy.stats import gaussian_kde +import os +from typing import Any +from duckdb import DuckDBPyConnection + +from ecoli.library.parquet_emitter import field_metadata +from ecoli.analysis.utils import create_base_to_extended_mapping + + +def load_fba_data( + conn: DuckDBPyConnection, history_sql: str, config_sql: str, sim_data_dict: dict +) -> tuple[pd.DataFrame, dict]: + """ + Load FBA flux data using DuckDB connection and SQL queries. + + Parameters: + - conn: DuckDB connection + - history_sql: SQL query for historical data + - config_sql: SQL query for configuration data + - sim_data_dict: Dictionary with sim_data information + + Returns: + - df: DataFrame with flux data (time points x extended reactions) + - metadata: Dictionary with experiment metadata + """ + print("[INFO] Loading FBA flux data via SQL...") + + try: + # Create base to extended reaction mapping + base_to_extended_mapping = create_base_to_extended_mapping(sim_data_dict) + if not base_to_extended_mapping: + raise Exception("Could not create base to extended reaction mapping") + + # Load the reaction IDs from the config - this is the array that maps to flux matrix columns + reaction_ids = field_metadata( + conn, config_sql, "listeners__fba_results__reaction_fluxes" + ) + print(f"[INFO] Total reactions in sim_data: {len(reaction_ids)}") + + # Required columns for the query + required_columns = [ + "time", + "generation", + "listeners__fba_results__reaction_fluxes", + ] + + # Build SQL query + sql = f""" + SELECT {", ".join(required_columns)} + FROM ({history_sql}) + ORDER BY generation, time + """ + + # Execute query + df_pl = conn.sql(sql).pl() + + if df_pl.is_empty(): + raise Exception("No data found") + print(f"[INFO] Loaded data with {df_pl.height} time steps") + + # Extract flux matrix and convert to pandas DataFrame + flux_matrix = df_pl["listeners__fba_results__reaction_fluxes"].to_numpy() + flux_matrix = np.array([np.array(row) for row in flux_matrix]) + + # Create DataFrame with extended reactions + df = pd.DataFrame(flux_matrix, columns=reaction_ids) + + # Add time information + time_data = df_pl.select(["time"]).to_pandas() + df = pd.concat([time_data, df], axis=1) + + # Drop initial time point where time == 0 + df = df[df["time"] != 0].reset_index(drop=True) + + # Create metadata dictionary + metadata = { + "n_extended_reactions": len(reaction_ids), + "n_timepoints": len(df), + "extended_reaction_names": reaction_ids, + "base_to_extended_mapping": base_to_extended_mapping, + } + + print("[INFO] Successfully loaded data:") + print(f" - Time points: {len(df)}") + print(f" - Extended reactions: {len(reaction_ids)}") + print(f" - Base reaction mapping entries: {len(base_to_extended_mapping)}") + + return df, metadata + + except Exception as e: + print(f"[ERROR] Failed to load data: {str(e)}") + raise + + +def map_extended_to_base_reactions(extended_reactions, base_to_extended_mapping): + """ + Map extended reactions to base reactions and identify forward/reverse relationships. + + Parameters: + - extended_reactions: List of extended reaction names + - base_to_extended_mapping: Dict mapping base reaction ID to list of extended reaction names + + Returns: + - base_reaction_mapping: Dict with base reaction info including forward/reverse extended reactions + - extended_to_base_map: Dict mapping each extended reaction to its base reaction + """ + # Create reverse mapping from extended to base + extended_to_base_map = {} + for base_rxn, extended_list in base_to_extended_mapping.items(): + for extended_rxn in extended_list: + extended_to_base_map[extended_rxn] = base_rxn + + base_reaction_mapping = {} + + for extended_reaction in extended_reactions: + # Get base reaction name from mapping + base_reaction = extended_to_base_map.get(extended_reaction) + + if base_reaction is None: + print( + f"[WARNING] No base reaction found for extended reaction: {extended_reaction}" + ) + continue + + # Initialize base reaction entry if not exists + if base_reaction not in base_reaction_mapping: + base_reaction_mapping[base_reaction] = { + "forward_extended": [], + "reverse_extended": [], + "all_extended": [], + } + + # Determine if this is a forward or reverse extended reaction + if extended_reaction.endswith(" (reverse)"): + base_reaction_mapping[base_reaction]["reverse_extended"].append( + extended_reaction + ) + else: + base_reaction_mapping[base_reaction]["forward_extended"].append( + extended_reaction + ) + + base_reaction_mapping[base_reaction]["all_extended"].append(extended_reaction) + + print("[INFO] Base reaction mapping results:") + print(f" - Total base reactions: {len(base_reaction_mapping)}") + print(f" - Extended reactions mapped: {len(extended_to_base_map)}") + + # Print statistics about forward/reverse distributions + forward_only = sum( + 1 + for info in base_reaction_mapping.values() + if len(info["forward_extended"]) > 0 and len(info["reverse_extended"]) == 0 + ) + reverse_only = sum( + 1 + for info in base_reaction_mapping.values() + if len(info["forward_extended"]) == 0 and len(info["reverse_extended"]) > 0 + ) + both_directions = sum( + 1 + for info in base_reaction_mapping.values() + if len(info["forward_extended"]) > 0 and len(info["reverse_extended"]) > 0 + ) + + print(f" - Base reactions with forward extended only: {forward_only}") + print(f" - Base reactions with reverse extended only: {reverse_only}") + print( + f" - Base reactions with both forward and reverse extended: {both_directions}" + ) + + return base_reaction_mapping, extended_to_base_map + + +def compute_base_reaction_fluxes(flux_df, base_reaction_mapping): + """ + Compute base reaction fluxes by summing forward extended and subtracting reverse extended fluxes. + + Parameters: + - flux_df: DataFrame with extended reaction flux values + - base_reaction_mapping: Dict with base reaction info + + Returns: + - base_flux_df: DataFrame with base reaction net flux values + - base_reaction_details: Dict with detailed info about each base reaction + """ + base_flux_data = {} + base_reaction_details = {} + + for base_reaction, info in base_reaction_mapping.items(): + forward_extended = info["forward_extended"] + reverse_extended = info["reverse_extended"] + + # Sum forward extended fluxes + forward_flux = pd.Series(0.0, index=flux_df.index) + if forward_extended: + for ext_reaction in forward_extended: + if ext_reaction in flux_df.columns: + forward_flux += flux_df[ext_reaction] + + # Sum reverse extended fluxes + reverse_flux = pd.Series(0.0, index=flux_df.index) + if reverse_extended: + for ext_reaction in reverse_extended: + if ext_reaction in flux_df.columns: + reverse_flux += flux_df[ext_reaction] + + # Net flux = forward - reverse + net_flux = forward_flux - reverse_flux + base_flux_data[base_reaction] = net_flux + + # Store details for analysis + base_reaction_details[base_reaction] = { + "forward_extended": forward_extended, + "reverse_extended": reverse_extended, + "n_forward_extended": len(forward_extended), + "n_reverse_extended": len(reverse_extended), + "total_extended": len(info["all_extended"]), + } + + base_flux_df = pd.DataFrame(base_flux_data) + + print("[INFO] Base reaction flux computation results:") + print(f" - Base reactions computed: {len(base_flux_df.columns)}") + print(f" - Time points: {len(base_flux_df)}") + + return base_flux_df, base_reaction_details + + +def categorize_base_reactions_by_flux_behavior(base_flux_df, eps=1e-30): + """ + Categorize base reactions based on their flux behavior across time steps. + + Parameters: + - base_flux_df: DataFrame with base reaction net flux values + - eps: Small tolerance for zero comparison + + Returns: + - always_positive: List of base reactions that are always >= 0 and have max > 0 + - always_negative: List of base reactions that are always <= 0 and have max abs > 0 + - oscillating: List of base reactions that change sign + - always_zero: List of base reactions that are always zero + - base_reaction_categories: Dictionary with detailed categorization info + """ + always_positive = [] + always_negative = [] + oscillating = [] + always_zero = [] + base_reaction_categories = {} + + for base_reaction in base_flux_df.columns: + flux_values = base_flux_df[base_reaction].values + + # Check for positive, negative, and zero values + has_positive = np.any(flux_values > eps) + has_negative = np.any(flux_values < -eps) + has_zero = np.any(np.abs(flux_values) <= eps) + + min_flux = flux_values.min() + max_flux = flux_values.max() + max_abs_flux = np.max(np.abs(flux_values)) + + # Categorize based on behavior + if max_abs_flux <= eps: # All values are essentially zero + always_zero.append(base_reaction) + category = "always_zero" + elif not has_negative and has_positive: # All values >= -eps and has some > eps + always_positive.append(base_reaction) + category = "always_positive" + elif not has_positive and has_negative: # All values <= eps and has some < -eps + always_negative.append(base_reaction) + category = "always_negative" + elif has_positive and has_negative: # Has both positive and negative values + oscillating.append(base_reaction) + category = "oscillating" + else: + # This case should be covered by always_zero, but keep as safety net + always_zero.append(base_reaction) + category = "always_zero" + + base_reaction_categories[base_reaction] = { + "category": category, + "min_flux": min_flux, + "max_flux": max_flux, + "max_abs_flux": max_abs_flux, + "has_positive": has_positive, + "has_negative": has_negative, + "has_zero": has_zero, + } + + print("\n[INFO] Base reaction categorization by flux behavior:") + print(f" - Always positive (>= 0, max > 0): {len(always_positive)}") + print(f" - Always negative (<= 0, max abs > 0): {len(always_negative)}") + print(f" - Oscillating (changes sign): {len(oscillating)}") + print(f" - Always zero (max abs ≈ 0): {len(always_zero)}") + + return ( + always_positive, + always_negative, + oscillating, + always_zero, + base_reaction_categories, + ) + + +def print_base_reaction_category_summaries( + positive_df, negative_df, oscillating_df, always_zero_df +): + """Print summary information for each base reaction category.""" + + if len(positive_df) > 0: + print( + "\n[INFO] Always Positive Base Reactions (max flux > 0) - Top 5 most active (lowest zero ratio):" + ) + top_positive = positive_df.sort_values("zero_ratio").head(5) + for _, row in top_positive.iterrows(): + print( + f" {row['base_reaction']}: zero_ratio={row['zero_ratio']:.4f}, max_flux={row['max_flux']:.2e}, ext_reactions={row['total_extended']}" + ) + + if len(negative_df) > 0: + print( + "\n[INFO] Always Negative Base Reactions (max abs flux > 0) - Top 5 most active (lowest zero ratio):" + ) + top_negative = negative_df.sort_values("zero_ratio").head(5) + for _, row in top_negative.iterrows(): + print( + f" {row['base_reaction']}: zero_ratio={row['zero_ratio']:.4f}, max_abs_flux={row['max_abs_flux']:.2e}, ext_reactions={row['total_extended']}" + ) + + if len(oscillating_df) > 0: + print( + "\n[INFO] Oscillating Base Reactions - Top 5 most active (lowest zero ratio):" + ) + top_oscillating = oscillating_df.sort_values("zero_ratio").head(5) + for _, row in top_oscillating.iterrows(): + print( + f" {row['base_reaction']}: zero_ratio={row['zero_ratio']:.4f}, max_abs_flux={row['max_abs_flux']:.2e}, ext_reactions={row['total_extended']}" + ) + + if len(always_zero_df) > 0: + print("\n[INFO] Always Zero Base Reactions - First 5 examples:") + first_zero = always_zero_df.head(5) + for _, row in first_zero.iterrows(): + print( + f" {row['base_reaction']}: max_abs_flux={row['max_abs_flux']:.2e}, ext_reactions={row['total_extended']}" + ) + + +def create_base_reaction_flux_plots( + flux_data, direction, filename_suffix, epsilon_log, base_reaction_details, outdir +): + """Create heat scatter and simple scatter plots for base reactions in a given flux direction.""" + + if len(flux_data) == 0: + return + + # Prepare data for plotting + x = flux_data["log_max_abs_flux"] + y = flux_data["zero_ratio"] + + # Calculate point density using gaussian_kde for heat scatter plot + if len(flux_data) > 1: # Need at least 2 points for KDE + xy = np.vstack([x, y]) + density = gaussian_kde(xy)(xy) + else: + density = np.array([1.0]) # Single point gets density of 1 + + # Create hover text with base reaction information + hover_text = [] + simple_hover_text = [] + + for idx, row in flux_data.iterrows(): + base_reaction = row["base_reaction"] + details = base_reaction_details.get(base_reaction, {}) + + # Create extended reaction info + forward_ext = details.get("forward_extended", []) + reverse_ext = details.get("reverse_extended", []) + + ext_info = f"Forward: {len(forward_ext)} extended, Reverse: {len(reverse_ext)} extended" + if len(forward_ext) <= 3: + ext_info += ( + f"
Forward: {', '.join(forward_ext) if forward_ext else 'None'}" + ) + if len(reverse_ext) <= 3: + ext_info += ( + f"
Reverse: {', '.join(reverse_ext) if reverse_ext else 'None'}" + ) + + # For heat scatter (with density) + hover_text.append( + f"Base Reaction: {base_reaction}
" + + f"Extended Reactions: {ext_info}
" + + f"Category: {row['category']}
" + + f"Zero Ratio: {row['zero_ratio']:.4f}
" + + f"Max |Net Flux|: {row['max_abs_flux']:.2e}
" + + f"Min Net Flux: {row['min_flux']:.2e}
" + + f"Max Net Flux: {row['max_flux']:.2e}
" + + f"Log |Max Net Flux|: {row['log_max_abs_flux']:.2f}
" + + f"Point Density: {density[list(flux_data.index).index(idx)]:.6f}" + ) + + # For simple scatter (without density) + simple_hover_text.append( + f"Base Reaction: {base_reaction}
" + + f"Extended Reactions: {ext_info}
" + + f"Category: {row['category']}
" + + f"Zero Ratio: {row['zero_ratio']:.4f}
" + + f"Max |Net Flux|: {row['max_abs_flux']:.2e}
" + + f"Min Net Flux: {row['min_flux']:.2e}
" + + f"Max Net Flux: {row['max_flux']:.2e}
" + + f"Log |Max Net Flux|: {row['log_max_abs_flux']:.2f}" + ) + + # 1. HEAT SCATTER PLOT WITH MARGINAL HISTOGRAMS + # Create subplot with marginal histograms + fig_heat = make_subplots( + rows=2, + cols=2, + column_widths=[0.9, 0.1], + row_heights=[0.1, 0.9], + specs=[ + [{"secondary_y": False}, {"secondary_y": False}], + [{"secondary_y": False}, {"secondary_y": False}], + ], + vertical_spacing=0.05, + horizontal_spacing=0.05, + subplot_titles=("", "", "", ""), + ) + + # Main scatter plot (bottom left, row=2, col=1) + fig_heat.add_trace( + go.Scatter( + x=x, + y=y, + mode="markers", + marker=dict( + size=8, + color=density, + colorscale="Plasma", + opacity=0.8, + colorbar=dict( + title=dict(text="Point Density", font=dict(size=14)), + tickfont=dict(size=12), + thickness=15, + len=0.7, + x=1.02, # Position colorbar to the right + ), + line=dict(width=0.5, color="white"), + ), + text=hover_text, + hovertemplate="%{text}", + name="Base Reactions", + showlegend=False, + ), + row=2, + col=1, + ) + + # Top density curve (x-axis distribution, row=1, col=1) + if len(x) > 1: + # Create smooth density curve for x-axis + x_range = np.linspace(x.min(), x.max(), 250) + x_density = gaussian_kde(x) + x_density_values = x_density(x_range) + else: + # Single point case + x_range = np.array([x.iloc[0]]) + x_density_values = np.array([1.0]) + + fig_heat.add_trace( + go.Scatter( + x=x_range, + y=x_density_values, + mode="lines", + line=dict(color="steelblue", width=3), + fill="tozeroy", + fillcolor="rgba(70, 130, 180, 0.3)", + name="X Density", + showlegend=False, + ), + row=1, + col=1, + ) + + # Right density curve (y-axis distribution, row=2, col=2) + if len(y) > 1: + # Create smooth density curve for y-axis + y_range = np.linspace(y.min(), y.max(), 250) + y_density = gaussian_kde(y) + y_density_values = y_density(y_range) + else: + # Single point case + y_range = np.array([y.iloc[0]]) + y_density_values = np.array([1.0]) + + fig_heat.add_trace( + go.Scatter( + x=y_density_values, + y=y_range, + mode="lines", + line=dict(color="lightcoral", width=3), + fill="tozerox", + fillcolor="rgba(240, 128, 128, 0.3)", + name="Y Density", + showlegend=False, + ), + row=2, + col=2, + ) + + # Update layout for heat scatter with histograms + fig_heat.update_layout( + title=dict( + text=f"Heat Scatter Plot with Marginal Density Curves: {direction} Base Reaction Net Flux", + font=dict(size=18), + x=0.5, + xanchor="center", + ), + plot_bgcolor="white", + paper_bgcolor="white", + font=dict(family="Arial", size=12), + width=1000, + height=800, + margin=dict(l=80, r=120, t=100, b=80), + ) + + # Update axes for main plot + fig_heat.update_xaxes( + title=dict(text="log₁₀(ε + |Max Net Flux|)", font=dict(size=14)), + tickfont=dict(size=12), + gridcolor="rgba(128,128,128,0.2)", + gridwidth=1, + row=2, + col=1, + ) + fig_heat.update_yaxes( + title=dict(text="Zero Flux Ratio", font=dict(size=14)), + tickfont=dict(size=12), + gridcolor="rgba(128,128,128,0.2)", + gridwidth=1, + row=2, + col=1, + ) + + # Update axes for histograms (remove tick labels and titles) + fig_heat.update_xaxes(showticklabels=False, title="", row=1, col=1) + fig_heat.update_yaxes(showticklabels=False, title="", row=1, col=1) + fig_heat.update_xaxes(showticklabels=False, title="", row=2, col=2) + fig_heat.update_yaxes(showticklabels=False, title="", row=2, col=2) + + # Hide the top-right subplot + fig_heat.update_xaxes(visible=False, row=1, col=2) + fig_heat.update_yaxes(visible=False, row=1, col=2) + + # Add statistics annotation + stats_text = ( + f"Base Reactions ({direction}): {len(flux_data):,}
" + + f"ε = {epsilon_log:.0e}
" + + f"|Max Net Flux| Range: {flux_data['max_abs_flux'].min():.2e} to {flux_data['max_abs_flux'].max():.2e}
" + + f"Zero Ratio Range: {flux_data['zero_ratio'].min():.4f} to {flux_data['zero_ratio'].max():.4f}
" + + f"Extended Reactions Range: {flux_data['total_extended'].min()} to {flux_data['total_extended'].max()}" + ) + + if len(flux_data) > 1: + stats_text += f"
Density Range: {density.min():.2e} - {density.max():.2e}" + + fig_heat.add_annotation( + x=0.02, + y=0.48, + xref="paper", + yref="paper", + text=stats_text, + showarrow=False, + font=dict(size=11, color="black"), + bgcolor="rgba(255,255,255,0.9)", + bordercolor="rgba(128,128,128,0.5)", + borderwidth=1, + borderpad=10, + xanchor="left", + yanchor="top", + ) + + # 2. SIMPLE SCATTER PLOT (original version without histograms) + fig_simple = go.Figure() + + fig_simple.add_trace( + go.Scatter( + x=x, + y=y, + mode="markers", + marker=dict( + size=6, + color="steelblue", + opacity=0.7, + line=dict(width=0.5, color="white"), + ), + text=simple_hover_text, + hovertemplate="%{text}", + name="Base Reactions", + ) + ) + + fig_simple.update_layout( + title=dict( + text=f"Simple Scatter Plot: {direction} Base Reaction Net Flux - Zero Ratio vs |Max Net Flux|", + font=dict(size=18), + x=0.5, + xanchor="center", + ), + xaxis=dict( + title=dict(text="log₁₀(ε + |Max Net Flux|)", font=dict(size=14)), + tickfont=dict(size=12), + gridcolor="rgba(128,128,128,0.2)", + gridwidth=1, + ), + yaxis=dict( + title=dict(text="Zero Flux Ratio", font=dict(size=14)), + tickfont=dict(size=12), + gridcolor="rgba(128,128,128,0.2)", + gridwidth=1, + ), + plot_bgcolor="white", + paper_bgcolor="white", + font=dict(family="Arial", size=12), + width=900, + height=700, + margin=dict(l=80, r=80, t=100, b=80), + showlegend=False, + ) + + # Add statistics for simple plot + simple_stats_text = ( + f"Base Reactions ({direction}): {len(flux_data):,}
" + + f"ε = {epsilon_log:.0e}
" + + f"|Max Net Flux| Range: {flux_data['max_abs_flux'].min():.2e} to {flux_data['max_abs_flux'].max():.2e}
" + + f"Zero Ratio Range: {flux_data['zero_ratio'].min():.4f} to {flux_data['zero_ratio'].max():.4f}
" + + f"Extended Reactions Range: {flux_data['total_extended'].min()} to {flux_data['total_extended'].max()}" + ) + + fig_simple.add_annotation( + x=0.02, + y=0.48, + xref="paper", + yref="paper", + text=simple_stats_text, + showarrow=False, + font=dict(size=11, color="black"), + bgcolor="rgba(255,255,255,0.9)", + bordercolor="rgba(128,128,128,0.5)", + borderwidth=1, + borderpad=10, + xanchor="left", + yanchor="top", + ) + + # Save plots + heat_filename = os.path.join( + outdir, + f"heat_scatter_with_density_curves_base_reactions_{filename_suffix}.html", + ) + simple_filename = os.path.join( + outdir, f"simple_scatter_plot_base_reactions_{filename_suffix}.html" + ) + + fig_heat.write_html(heat_filename) + fig_simple.write_html(simple_filename) + + print(f"\n[INFO] {direction} base reaction flux plots saved:") + print(f" - {heat_filename}") + print(f" - {simple_filename}") + + +def save_base_reaction_results( + comprehensive_df, + oscillating_df, + always_zero_df, + base_reaction_details, + base_reaction_mapping, + active_base_flux_df, + metadata, + outdir, +): + """Save all base reaction results to CSV files with experiment metadata in outdir.""" + + # Create output directory if it doesn't exist + os.makedirs(outdir, exist_ok=True) + + # Create filename prefix + prefix = "base_reaction_analysis" + + # Save comprehensive base reaction metrics + comprehensive_filename = os.path.join(outdir, f"{prefix}_metrics.csv") + comprehensive_df.to_csv(comprehensive_filename, index=False) + print( + f"\n[INFO] Comprehensive base reaction metrics saved to '{comprehensive_filename}'" + ) + + # Save oscillating base reactions specifically + if len(oscillating_df) > 0: + oscillating_filename = os.path.join(outdir, f"{prefix}_oscillating.csv") + oscillating_df.to_csv(oscillating_filename, index=False) + print(f"[INFO] Oscillating base reactions saved to '{oscillating_filename}'") + else: + print("[INFO] No oscillating base reactions found.") + + # Save always zero base reactions specifically + if len(always_zero_df) > 0: + zero_filename = os.path.join(outdir, f"{prefix}_always_zero.csv") + always_zero_df.to_csv(zero_filename, index=False) + print(f"[INFO] Always zero base reactions saved to '{zero_filename}'") + else: + print("[INFO] No always zero base reactions found.") + + # Save base reaction mapping details + mapping_data = [] + for base_reaction, info in base_reaction_mapping.items(): + mapping_data.append( + { + "base_reaction": base_reaction, + "forward_extended": "; ".join(info["forward_extended"]), + "reverse_extended": "; ".join(info["reverse_extended"]), + "n_forward_extended": len(info["forward_extended"]), + "n_reverse_extended": len(info["reverse_extended"]), + "total_extended": len(info["all_extended"]), + } + ) + + mapping_df = pd.DataFrame(mapping_data) + mapping_filename = os.path.join(outdir, f"{prefix}_extended_mapping.csv") + mapping_df.to_csv(mapping_filename, index=False) + print( + f"[INFO] Base reaction to extended reaction mapping saved to '{mapping_filename}'" + ) + + # Save filtered active base reaction flux data + flux_filename = os.path.join(outdir, f"{prefix}_filtered_flux.csv") + active_base_flux_df.to_csv(flux_filename, index=False, encoding="utf-8-sig") + print(f"[INFO] Filtered active base reactions saved to '{flux_filename}'") + + # Save metadata + metadata_filename = os.path.join(outdir, f"{prefix}_metadata.csv") + metadata_for_csv = { + k: v for k, v in metadata.items() if not isinstance(v, (dict, list, np.ndarray)) + } # Only save simple types + metadata_for_csv["n_base_reactions"] = len(comprehensive_df) + metadata_df = pd.DataFrame([metadata_for_csv]) + metadata_df.to_csv(metadata_filename, index=False) + print(f"[INFO] Experiment metadata saved to '{metadata_filename}'") + + # Print detailed summary statistics by category + print("\n[INFO] Detailed Summary Statistics by Category for Base Reactions:") + + for category in [ + "always_positive", + "always_negative", + "oscillating", + "always_zero", + ]: + cat_df = comprehensive_df[comprehensive_df["category"] == category] + if len(cat_df) > 0: + print( + f"\n {category.replace('_', ' ').title()} Base Reactions ({len(cat_df)}):" + ) + print( + f" Zero ratio range: {cat_df['zero_ratio'].min():.4f} - {cat_df['zero_ratio'].max():.4f}" + ) + print( + f" |Max net flux| range: {cat_df['max_abs_flux'].min():.2e} - {cat_df['max_abs_flux'].max():.2e}" + ) + print( + f" Min net flux range: {cat_df['min_flux'].min():.2e} - {cat_df['min_flux'].max():.2e}" + ) + print( + f" Max net flux range: {cat_df['max_flux'].min():.2e} - {cat_df['max_flux'].max():.2e}" + ) + print( + f" Extended reactions per base: {cat_df['total_extended'].min()} - {cat_df['total_extended'].max()}" + ) + + # Count reactions with different flux behaviors + has_zero_count = cat_df["has_zero"].sum() + print(f" Base reactions with zero flux points: {has_zero_count}") + + # Extended reaction statistics + total_forward_ext = cat_df["n_forward_extended"].sum() + total_reverse_ext = cat_df["n_reverse_extended"].sum() + print(f" Total forward extended reactions: {total_forward_ext}") + print(f" Total reverse extended reactions: {total_reverse_ext}") + + print(f"\nTotal base reactions: {len(comprehensive_df)}") + print(f"Total extended reactions mapped: {metadata['n_extended_reactions']}") + + +def plot( + params: dict[str, Any], + conn: DuckDBPyConnection, + history_sql: str, + config_sql: str, + success_sql: str, + sim_data_dict: dict[str, dict[int, str]], + validation_data_paths: list[str], + outdir: str, + variant_metadata: dict[str, dict[int, Any]], + variant_names: dict[str, str], +): + """ + Preprocesses FBA flux data by mapping extended reactions to base reactions, + computes net fluxes for base reactions (forward extended - reverse extended), + categorizes base reactions based on flux behavior, and creates separate + visualizations for each category with marginal histograms. + + Parameters from params dict: + - zero_threshold: Fraction of zeros above which a base reaction is considered inactive (default: 0.999) + - eps: Small tolerance for zero comparison (default: 1e-30) + - epsilon_log: Small value to add to max flux for log transformation (default: 1e-30) + """ + + # Get parameters with defaults + zero_threshold = params.get("zero_threshold", 0.999) + eps = params.get("eps", 1e-30) + epsilon_log = params.get("epsilon_log", 1e-30) + + print("[INFO] Starting base reaction flux analysis...") + print( + f"[INFO] Parameters: zero_threshold={zero_threshold}, eps={eps}, epsilon_log={epsilon_log}" + ) + print(f"[INFO] Output directory: {outdir}") + + try: + # Load data + df, metadata = load_fba_data(conn, history_sql, config_sql, sim_data_dict) + + # Separate flux data (drop time columns) + extended_flux_df = df.drop(columns=["time"]) + + print( + f"[INFO] Original data: {len(extended_flux_df.columns)} extended reactions, {len(extended_flux_df)} time points" + ) + + # Map extended reactions to base reactions + base_reaction_mapping, extended_to_base_map = map_extended_to_base_reactions( + extended_flux_df.columns.tolist(), metadata["base_to_extended_mapping"] + ) + + # Compute base reaction fluxes + base_flux_df, base_reaction_details = compute_base_reaction_fluxes( + extended_flux_df, base_reaction_mapping + ) + + # Categorize base reactions by flux behavior + ( + always_positive, + always_negative, + oscillating, + always_zero, + base_reaction_categories, + ) = categorize_base_reactions_by_flux_behavior(base_flux_df, eps) + + # Compute zero-flux ratio per base reaction (using absolute values) + zero_counts = (np.abs(base_flux_df) <= eps).sum(axis=0) + zero_ratio = zero_counts / base_flux_df.shape[0] + + # Create comprehensive dataframe with all metrics + comprehensive_data = [] + for base_reaction in base_flux_df.columns: + cat_info = base_reaction_categories[base_reaction] + details = base_reaction_details[base_reaction] + comprehensive_data.append( + { + "base_reaction": base_reaction, + "category": cat_info["category"], + "zero_ratio": zero_ratio[base_reaction], + "min_flux": cat_info["min_flux"], + "max_flux": cat_info["max_flux"], + "max_abs_flux": cat_info["max_abs_flux"], + "log_max_abs_flux": np.log10( + epsilon_log + cat_info["max_abs_flux"] + ), + "has_positive": cat_info["has_positive"], + "has_negative": cat_info["has_negative"], + "has_zero": cat_info["has_zero"], + "n_forward_extended": details["n_forward_extended"], + "n_reverse_extended": details["n_reverse_extended"], + "total_extended": details["total_extended"], + } + ) + + comprehensive_df = pd.DataFrame(comprehensive_data) + + # Filter out base reactions with high zero_ratio + active_base_reactions = comprehensive_df[ + comprehensive_df["zero_ratio"] < zero_threshold + ]["base_reaction"].tolist() + active_base_flux_df = base_flux_df[active_base_reactions] + print( + f"\n[INFO] Base reactions remaining after filtering (zero_ratio < {zero_threshold}): {len(active_base_reactions)}" + ) + + # Create separate datasets for visualization + positive_df = comprehensive_df[ + (comprehensive_df["category"] == "always_positive") + & (comprehensive_df["max_flux"] > 0) + ].copy() + + negative_df = comprehensive_df[ + (comprehensive_df["category"] == "always_negative") + & (comprehensive_df["max_abs_flux"] > 0) + ].copy() + + oscillating_df = comprehensive_df[ + comprehensive_df["category"] == "oscillating" + ].copy() + always_zero_df = comprehensive_df[ + comprehensive_df["category"] == "always_zero" + ].copy() + + # Print filtering results + print("\n[INFO] Filtered datasets for visualization:") + print(f" - Always positive with max flux > 0: {len(positive_df)}") + print(f" - Always negative with max abs flux > 0: {len(negative_df)}") + print(f" - Oscillating: {len(oscillating_df)}") + print(f" - Always zero: {len(always_zero_df)}") + + # Print top/bottom reactions for each category + print_base_reaction_category_summaries( + positive_df, negative_df, oscillating_df, always_zero_df + ) + + # Create visualizations with marginal histograms + if len(positive_df) > 0: + create_base_reaction_flux_plots( + positive_df, + "Always Positive", + "always_positive", + epsilon_log, + base_reaction_details, + outdir, + ) + else: + print( + "\n[WARNING] No always positive base reactions found. Skipping positive plots." + ) + + if len(negative_df) > 0: + create_base_reaction_flux_plots( + negative_df, + "Always Negative", + "always_negative", + epsilon_log, + base_reaction_details, + outdir, + ) + else: + print( + "\n[WARNING] No always negative base reactions found. Skipping negative plots." + ) + + # Save results + save_base_reaction_results( + comprehensive_df, + oscillating_df, + always_zero_df, + base_reaction_details, + base_reaction_mapping, + active_base_flux_df, + metadata, + outdir, + ) + + print("\n[INFO] Base reaction flux preprocessing and visualization complete.") + print(f"[INFO] All files saved to directory: {outdir}") + + return ( + comprehensive_df, + active_base_flux_df, + oscillating_df, + always_zero_df, + base_reaction_details, + base_reaction_mapping, + metadata, + ) + + except Exception as e: + print(f"[ERROR] Analysis failed: {str(e)}") + import traceback + + traceback.print_exc() + return None diff --git a/ecoli/analysis/single/fba_flux_pca.py b/ecoli/analysis/single/fba_flux_pca.py new file mode 100644 index 000000000..59b9469ae --- /dev/null +++ b/ecoli/analysis/single/fba_flux_pca.py @@ -0,0 +1,380 @@ +""" +Visualize FBA reaction flux dynamics using PCA trajectory analysis. + +You can specify the reactions and time window using parameters: + "fba_flux_pca": { + # Required: specify BioCyc IDs of reactions to analyze + "BioCyc_ID": ["Name1", "Name2", ...], # Reactions of interest + # Optional: specify time window to analyze + # If not specified, all time points will be used + "time_window": [start_time, end_time] # in seconds + } + +This script uses the base reaction ID to extended reaction mapping to efficiently +find forward and reverse reactions, then calculates net flux using SQL for +optimal memory usage and performance. + +For the specified reactions, each timestep forms a vector of net flux values. +PCA is applied to reduce dimensionality to 2D and visualize the metabolic trajectory. +""" + +import altair as alt +import os +from typing import Any +import numpy as np +from sklearn.decomposition import PCA +from sklearn.preprocessing import StandardScaler + +import polars as pl +from duckdb import DuckDBPyConnection +import pandas as pd + +from ecoli.library.parquet_emitter import field_metadata +from ecoli.analysis.utils import ( + create_base_to_extended_mapping, + build_flux_calculation_sql, +) + + +def perform_pca_analysis(net_flux_df, valid_biocyc_ids, time_points): + """ + Perform PCA analysis on the net flux data. + + Args: + net_flux_df: Polars DataFrame with net flux columns + valid_biocyc_ids: List of BioCyc IDs (reaction names) + time_points: Time points corresponding to each row + + Returns: + dict: PCA results including transformed data, components, and explained variance + """ + + print("[INFO] Performing PCA analysis...") + + # Select only the net flux columns for PCA + net_flux_cols = [f"{biocyc_id}_net_flux" for biocyc_id in valid_biocyc_ids] + + # Check that all columns exist + missing_cols = [col for col in net_flux_cols if col not in net_flux_df.columns] + if missing_cols: + print(f"[ERROR] Missing net flux columns: {missing_cols}") + return None + + # Extract net flux matrix (n_timepoints x n_reactions) + net_flux_matrix = net_flux_df.select(net_flux_cols).to_numpy() + + print(f"[INFO] Net flux matrix shape for PCA: {net_flux_matrix.shape}") + + # Print basic statistics for each reaction + for i, biocyc_id in enumerate(valid_biocyc_ids): + flux_values = net_flux_matrix[:, i] + print( + f"[INFO] Net flux stats for {biocyc_id}: " + f"mean={flux_values.mean():.6f}, std={flux_values.std():.6f} (mmol/gDW/hr)" + ) + + # Standardize the data (important for PCA) + scaler = StandardScaler() + scaled_matrix = scaler.fit_transform(net_flux_matrix) + + # Perform PCA to reduce to 2 dimensions + pca = PCA(n_components=2) + pca_result = pca.fit_transform(scaled_matrix) + + # Calculate explained variance + explained_variance = pca.explained_variance_ratio_ + cumulative_variance = np.cumsum(explained_variance) + + print( + f"[INFO] PCA explained variance: PC1={explained_variance[0]:.3f}, PC2={explained_variance[1]:.3f}" + ) + print(f"[INFO] Cumulative explained variance: {cumulative_variance[1]:.3f}") + + # Get component loadings (contribution of each reaction to each PC) + components = pca.components_ + + # Create loading data for visualization + loadings_data = [] + for i, biocyc_id in enumerate(valid_biocyc_ids): + loadings_data.append( + { + "BioCyc_ID": biocyc_id, + "PC1_loading": components[0, i], + "PC2_loading": components[1, i], + "PC1_abs": abs(components[0, i]), + "PC2_abs": abs(components[1, i]), + } + ) + + pca_results = { + "pca_coordinates": pca_result, + "time_points": time_points, + "explained_variance": explained_variance, + "cumulative_variance": cumulative_variance[1], + "components": components, + "loadings_data": loadings_data, + "scaler": scaler, + "pca_model": pca, + } + + return pca_results + + +def plot( + params: dict[str, Any], + conn: DuckDBPyConnection, + history_sql: str, + config_sql: str, + success_sql: str, + sim_data_dict: dict[str, dict[int, str]], + validation_data_paths: list[str], + outdir: str, + variant_metadata: dict[str, dict[int, Any]], + variant_names: dict[str, str], +): + """Visualize FBA flux dynamics using PCA trajectory analysis.""" + + # Get parameters + biocyc_ids = params.get("BioCyc_ID", []) + if not biocyc_ids: + print( + "[ERROR] No BioCyc_ID found in params. Please specify reaction IDs for PCA analysis." + ) + return None + + if isinstance(biocyc_ids, str): + biocyc_ids = [biocyc_ids] + + # Get time window (optional) + time_window = params.get("time_window", None) + if time_window is not None: + if len(time_window) != 2: + print( + "[ERROR] time_window must be a list of [start_time, end_time] in seconds." + ) + return None + start_time, end_time = time_window + print(f"[INFO] Time window: {start_time}s to {end_time}s") + else: + print("[INFO] Using full time range") + + print(f"[INFO] PCA analysis with {len(biocyc_ids)} reactions: {biocyc_ids}") + + # Create base to extended reaction mapping + base_to_extended_mapping = create_base_to_extended_mapping(sim_data_dict) + if not base_to_extended_mapping: + print("[ERROR] Could not create base to extended reaction mapping") + return None + + # Load reaction IDs from config + try: + all_reaction_ids = field_metadata( + conn, config_sql, "listeners__fba_results__reaction_fluxes" + ) + print(f"[INFO] Total reactions in sim_data: {len(all_reaction_ids)}") + except Exception as e: + print(f"[ERROR] Error loading reaction IDs: {e}") + return None + + # Build SQL query for efficient flux calculation + flux_calculation_sql, valid_biocyc_ids = build_flux_calculation_sql( + biocyc_ids, base_to_extended_mapping, all_reaction_ids, history_sql + ) + + if not flux_calculation_sql or not valid_biocyc_ids: + print("[ERROR] Could not build flux calculation SQL") + return None + + if len(valid_biocyc_ids) < 2: + print("[ERROR] Need at least 2 valid reactions for PCA analysis") + return None + + print(f"[INFO] Processing {len(valid_biocyc_ids)} valid BioCyc IDs") + + # Execute the optimized SQL query + try: + df = conn.sql(flux_calculation_sql).pl() + print(f"[INFO] Loaded data with {df.height} time steps") + except Exception as e: + print(f"[ERROR] Error executing flux calculation SQL: {e}") + return None + + if df.is_empty(): + print("[ERROR] No data found") + return None + + # Apply time window filter if specified + if time_window is not None: + start_time_min = start_time / 60 + end_time_min = end_time / 60 + df = df.filter( + (pl.col("time_min") >= start_time_min) + & (pl.col("time_min") <= end_time_min) + ) + print( + f"[INFO] Filtered to time window: {start_time_min:.2f} - {end_time_min:.2f} minutes" + ) + + if df.height < 3: + print("[ERROR] Need at least 3 time points for meaningful PCA analysis") + return None + + print(f"[INFO] Final dataset has {df.height} time steps") + + # Get time points for trajectory + time_points = df.select("time_min").to_numpy().flatten() + + # Perform PCA analysis + pca_results = perform_pca_analysis(df, valid_biocyc_ids, time_points) + + if pca_results is None: + print("[ERROR] PCA analysis failed") + return None + + # Create visualization functions + def create_pca_trajectory_chart(pca_results): + """Create PCA trajectory visualization.""" + + # Prepare trajectory data + pca_coords = pca_results["pca_coordinates"] + time_points = pca_results["time_points"] + + trajectory_data = pd.DataFrame( + { + "PC1": pca_coords[:, 0], + "PC2": pca_coords[:, 1], + "Time_min": time_points, + "Point_idx": range(len(time_points)), + } + ) + + # Create trajectory line + trajectory_line = ( + alt.Chart(trajectory_data) + .mark_line(strokeWidth=2, color="steelblue") + .encode( + x=alt.X( + "PC1:Q", + title=f"PC1 ({pca_results['explained_variance'][0]:.1%} variance)", + ), + y=alt.Y( + "PC2:Q", + title=f"PC2 ({pca_results['explained_variance'][1]:.1%} variance)", + ), + order="Point_idx:O", + ) + ) + + # Add points with time color coding + trajectory_points = ( + alt.Chart(trajectory_data) + .mark_circle(size=60, stroke="white", strokeWidth=1) + .encode( + x=alt.X("PC1:Q"), + y=alt.Y("PC2:Q"), + color=alt.Color( + "Time_min:Q", title="Time (min)", scale=alt.Scale(scheme="viridis") + ), + tooltip=["Time_min:Q", "PC1:Q", "PC2:Q"], + ) + ) + + # Mark start and end points + start_point = ( + alt.Chart(trajectory_data.head(1)) + .mark_circle(size=100, color="green", stroke="white", strokeWidth=2) + .encode(x=alt.X("PC1:Q"), y=alt.Y("PC2:Q"), tooltip=alt.value("Start")) + ) + + end_point = ( + alt.Chart(trajectory_data.tail(1)) + .mark_circle(size=100, color="red", stroke="white", strokeWidth=2) + .encode(x=alt.X("PC1:Q"), y=alt.Y("PC2:Q"), tooltip=alt.value("End")) + ) + + # Combine all elements + pca_chart = ( + trajectory_line + trajectory_points + start_point + end_point + ).properties( + title=f"PCA Trajectory ({pca_results['cumulative_variance']:.1%} variance explained)", + width=500, + height=400, + ) + + return pca_chart + + def create_loadings_chart(pca_results): + """Create PCA loadings visualization.""" + + loadings_df = pd.DataFrame(pca_results["loadings_data"]) + + # Create biplot showing variable loadings + loadings_chart = ( + alt.Chart(loadings_df) + .mark_circle(size=100, stroke="black", strokeWidth=1) + .encode( + x=alt.X("PC1_loading:Q", title="PC1 Loading"), + y=alt.Y("PC2_loading:Q", title="PC2 Loading"), + color=alt.Color("BioCyc_ID:N", title="Reaction"), + tooltip=["BioCyc_ID:N", "PC1_loading:Q", "PC2_loading:Q"], + ) + ) + + # Add reaction labels + loadings_text = ( + alt.Chart(loadings_df) + .mark_text(dx=10, dy=-10, fontSize=10) + .encode( + x=alt.X("PC1_loading:Q"), y=alt.Y("PC2_loading:Q"), text="BioCyc_ID:N" + ) + ) + + # Add reference lines + zero_line_x = ( + alt.Chart(pd.DataFrame({"x": [0]})) + .mark_rule(color="gray", strokeDash=[2, 2]) + .encode(x="x:Q") + ) + zero_line_y = ( + alt.Chart(pd.DataFrame({"y": [0]})) + .mark_rule(color="gray", strokeDash=[2, 2]) + .encode(y="y:Q") + ) + + combined_loadings = ( + zero_line_x + zero_line_y + loadings_chart + loadings_text + ).properties( + title="PCA Loadings - Reaction Contributions", width=500, height=400 + ) + + return combined_loadings + + # Create visualizations + trajectory_chart = create_pca_trajectory_chart(pca_results) + loadings_chart = create_loadings_chart(pca_results) + + # Combine charts side by side + combined_plot = alt.hconcat(trajectory_chart, loadings_chart).resolve_scale( + color="independent" + ) + + # Add overall title + time_window_str = ( + f" (Time: {time_window[0] / 60:.1f}-{time_window[1] / 60:.1f} min)" + if time_window + else "" + ) + combined_plot = combined_plot.properties( + title=alt.TitleParams( + text=f"FBA Flux PCA Trajectory Analysis{time_window_str}", + fontSize=16, + anchor="start", + ) + ) + + # Save the plot + output_path = os.path.join(outdir, "fba_flux_pca_trajectory.html") + combined_plot.save(output_path) + print(f"[INFO] Saved PCA visualization to: {output_path}") + + return combined_plot diff --git a/ecoli/analysis/utils.py b/ecoli/analysis/utils.py new file mode 100644 index 000000000..c50acfc6e --- /dev/null +++ b/ecoli/analysis/utils.py @@ -0,0 +1,185 @@ +"""Helper functions for vEcoli analysis""" + +from typing import List, Tuple, Dict, Optional +import pickle +from collections import defaultdict + +from ecoli.library.parquet_emitter import open_arbitrary_sim_data + + +def categorize_reactions(extended_reactions: List[str]) -> Tuple[List[str], List[str]]: + """ + Categorize extended reactions into forward and reverse based on naming convention. + + Args: + extended_reactions: List of extended reaction names + + Returns: + Tuple of (forward_reactions, reverse_reactions) + """ + forward_reactions = [] + reverse_reactions = [] + + for rxn in extended_reactions: + if rxn.endswith(" (reverse)"): + reverse_reactions.append(rxn) + else: + forward_reactions.append(rxn) + + return forward_reactions, reverse_reactions + + +def build_flux_calculation_sql( + biocyc_ids: List[str], + base_to_extended_mapping: Dict[str, List[str]], + all_reaction_ids: List[str], + history_sql: str, +) -> Tuple[Optional[str], List[str]]: + """ + Build SQL query to calculate net fluxes directly in DuckDB for optimal performance. + + This function generates an optimized SQL query that calculates net flux + (forward_flux - reverse_flux) for each specified BioCyc ID using streaming + computation, avoiding the need to load large flux matrices into memory. + + Args: + biocyc_ids: List of BioCyc IDs (base reaction IDs) to analyze + base_to_extended_mapping: Mapping from base reaction ID to list of extended reaction names + all_reaction_ids: List of all reaction IDs from field_metadata (defines flux array order) + history_sql: SQL query string for accessing historical simulation data + + Returns: + Tuple of (sql_query_string, valid_biocyc_ids_list) + - sql_query_string: Complete SQL query for flux calculation, or None if no valid reactions + - valid_biocyc_ids_list: List of BioCyc IDs that have valid reactions found + + Example: + For a reaction with forward indices [3,4] and reverse indices [5,6], generates: + ``` + (fluxes[4] + fluxes[5]) - (fluxes[6] + fluxes[7]) AS "REACTION-ID_net_flux" + ``` + Note: Indices are converted from 0-based (Python) to 1-based (SQL) + """ + flux_calculations = [] + valid_biocyc_ids = [] + + for biocyc_id in biocyc_ids: + extended_reactions = base_to_extended_mapping.get(biocyc_id, []) + + if not extended_reactions: + print(f"[WARNING] No extended reactions found for BioCyc ID: {biocyc_id}") + continue + + # Separate forward and reverse reactions + forward_reactions, reverse_reactions = categorize_reactions(extended_reactions) + forward_indices = [] + reverse_indices = [] + + for rxn_name in forward_reactions: + try: + idx = all_reaction_ids.index(rxn_name) + forward_indices.append(idx) + except ValueError: + print(f"[WARNING] Reaction {rxn_name} not found in flux array") + + for rxn_name in reverse_reactions: + try: + idx = all_reaction_ids.index(rxn_name) + reverse_indices.append(idx) + except ValueError: + print(f"[WARNING] Reaction {rxn_name} not found in flux array") + + if not forward_indices and not reverse_indices: + print( + f"[WARNING] No valid reaction indices found for BioCyc ID: {biocyc_id}" + ) + continue + + print( + f"[INFO] {biocyc_id}: {len(forward_reactions)} forward, {len(reverse_reactions)} reverse reactions" + ) + + # Build SQL expression for net flux calculation + # Convert to 1-based indexing for SQL (DuckDB arrays are 1-indexed) + forward_terms = [] + if forward_indices: + forward_terms = [f"fluxes[{idx + 1}]" for idx in forward_indices] + + reverse_terms = [] + if reverse_indices: + reverse_terms = [f"fluxes[{idx + 1}]" for idx in reverse_indices] + + # Construct the net flux calculation expression + flux_expr_parts = [] + + # Add forward flux terms + if forward_terms: + if len(forward_terms) == 1: + flux_expr_parts.append(forward_terms[0]) + else: + flux_expr_parts.append(f"({' + '.join(forward_terms)})") + else: + flux_expr_parts.append("0") + + # Subtract reverse flux terms + if reverse_terms: + if len(reverse_terms) == 1: + flux_expr_parts.append(f" - {reverse_terms[0]}") + else: + flux_expr_parts.append(f" - ({' + '.join(reverse_terms)})") + + net_flux_expr = "".join(flux_expr_parts) + + # Escape column name with quotes to handle special characters like hyphens + safe_column_name = f'"{biocyc_id}_net_flux"' + flux_calculations.append(f"{net_flux_expr} AS {safe_column_name}") + valid_biocyc_ids.append(biocyc_id) + + if not flux_calculations: + print("[ERROR] No valid flux calculations could be built") + return None, [] + + # Build complete SQL query with CTE for better readability and performance + sql = f""" + WITH renamed AS ( + SELECT time, generation, variant, listeners__fba_results__reaction_fluxes AS fluxes + FROM ({history_sql}) + ) + SELECT + time, + generation, + variant, + time / 60.0 AS time_min, + {", ".join(flux_calculations)} + FROM renamed + ORDER BY generation, time + """ + + return sql, valid_biocyc_ids + + +def create_base_to_extended_mapping(sim_data_dict): + """ + Create reverse mapping from base reaction ID to extended fab reactions. + + Args: + sim_data_dict: Dictionary containing sim_data information + + Returns: + dict: Mapping from base reaction ID to list of extended reaction names + """ + # Load sim_data + with open_arbitrary_sim_data(sim_data_dict) as f: + sim_data = pickle.load(f) + reaction_ids = sim_data.process.metabolism.reaction_id_to_base_reaction_id + + if not reaction_ids: + print("[WARNING] Could not find reaction_id_to_base_reaction_id in sim_data") + return {} + + # Create reverse mapping + base_to_extended_mapping = defaultdict(list) + for extended_rxn, base_rxn_id in reaction_ids.items(): + base_to_extended_mapping[base_rxn_id].append(extended_rxn) + + return dict(base_to_extended_mapping) diff --git a/ecoli/experiments/ecoli_master_sim.py b/ecoli/experiments/ecoli_master_sim.py index 50931150c..0c6503ac5 100644 --- a/ecoli/experiments/ecoli_master_sim.py +++ b/ecoli/experiments/ecoli_master_sim.py @@ -23,9 +23,10 @@ import numpy as np from vivarium.core.engine import Engine +from vivarium.core.composer import deep_merge from vivarium.core.process import Process from vivarium.core.serialize import deserialize_value, serialize_value -from vivarium.library.dict_utils import deep_merge, deep_merge_check +from vivarium.library.dict_utils import deep_merge_check from vivarium.library.topology import inverse_topology from vivarium.library.topology import assoc_path from ecoli.library.logging_tools import write_json @@ -116,11 +117,10 @@ def get_git_diff() -> str: If that fails, tries to read the diff from source-info/git-diff.txt file. Raises an error if both methods fail. """ - # Try to run git command try: return ( subprocess.check_output(["git", "-C", CONFIG_DIR_PATH, "diff", "HEAD"]) - .decode("ascii") + .decode("utf-8") .strip() ) except (subprocess.CalledProcessError, FileNotFoundError): @@ -732,7 +732,45 @@ def build_ecoli(self): self.generated_initial_state, initial_environment ) - def save_states(self, daughter_outdir: str = ""): + def update_experiment(self, time_to_update: float = 0.0): + """ + Runs the E. coli simulation for a specified amount of time. If the + simulation reaches a division event and ``config['generations']`` is set, + it will save the daughter cell states to JSON files in the directory + specified by ``config['daughter_outdir']``. Also creates a file + ``division_time.sh`` that, when executed, sets the environment variable + ``division_time`` to the time at which division occurred (used in + Nextflow workflow runs). + """ + try: + self.ecoli_experiment.update(time_to_update) + except DivisionDetected: + state = self.ecoli_experiment.state.get_value(condition=not_a_process) + assert len(state["agents"]) == 2 + for i, agent_state in enumerate(state["agents"].values()): + prepare_save_state(agent_state) + daughter_path = os.path.join( + self.daughter_outdir, f"daughter_state_{i}.json" + ) + write_json(daughter_path, agent_state) + print( + f"Divided at t = {self.ecoli_experiment.global_time} after " + f"{self.ecoli_experiment.global_time - self.initial_global_time} sec." + ) + with open("division_time.sh", "w") as f: + f.write(f"export division_time={self.ecoli_experiment.global_time}") + # Tell Parquet emitter that simulation was successful + if isinstance(self.ecoli_experiment.emitter, ParquetEmitter): + self.ecoli_experiment.emitter.success = True + self.ecoli_experiment.emitter.finalize() + # Exit so that EcoliSim.run() does not raise TimeLimitError + sys.exit() + finally: + # Finish writing any buffered emits to Parquet files + if isinstance(self.ecoli_experiment.emitter, ParquetEmitter): + self.ecoli_experiment.emitter.finalize() + + def save_states(self): """ Runs the simulation while saving the states of specific timesteps to files named ``data/vivecoli_t{time}.json``. Invoked by @@ -740,12 +778,6 @@ def save_states(self, daughter_outdir: str = ""): if ``config['save'] == True``. State is saved as a JSON that can be reloaded into a simulation as described in :py:meth:`~ecoli.composites.ecoli_master.Ecoli.initial_state`. - - Args: - daughter_outdir: Location to write JSON files for daughter cell(s). - Only used if ``config`` contains ``generations`` key specifying - number of generations to simulate. Nextflow chains simulations - together by passing saved daughter states to new processes. """ for time in self.save_times: if time > self.max_duration: @@ -759,27 +791,7 @@ def save_states(self, daughter_outdir: str = ""): time_to_next_save = self.save_times[i] else: time_to_next_save = self.save_times[i] - self.save_times[i - 1] - try: - self.ecoli_experiment.update(time_to_next_save) - except DivisionDetected: - state = self.ecoli_experiment.state.get_value(condition=not_a_process) - assert len(state["agents"]) == 2 - for i, agent_state in enumerate(state["agents"].values()): - prepare_save_state(agent_state) - daughter_path = os.path.join( - daughter_outdir, f"daughter_state_{i}.json" - ) - write_json(daughter_path, agent_state) - print( - f"Divided at t = {self.ecoli_experiment.global_time} after" - f"{self.ecoli_experiment.global_time - self.initial_global_time} sec." - ) - with open("division_time.sh", "w") as f: - f.write(f"export division_time={self.ecoli_experiment.global_time}") - # Tell Parquet emitter that simulation was successful - if isinstance(self.ecoli_experiment.emitter, ParquetEmitter): - self.ecoli_experiment.emitter.success = True - sys.exit() + self.update_experiment(time_to_next_save) time_elapsed = self.save_times[i] state = self.ecoli_experiment.state.get_value(condition=not_a_process) if self.divide: @@ -791,10 +803,13 @@ def save_states(self, daughter_outdir: str = ""): print("Finished saving the state at t = " + str(time_elapsed)) time_remaining = self.max_duration - self.save_times[-1] if time_remaining: - self.ecoli_experiment.update(time_remaining) + self.update_experiment(time_remaining) def run(self): - """Create and run an EcoliSim experiment. + """Create and run an EcoliSim experiment. If the simulation reaches + the maximum duration specified by ``config['max_duration']``, it will + raise a :py:class:`~ecoli.experiments.ecoli_master_sim.TimeLimitError` + if ``config['fail_at_max_duration']`` is ``True``. .. WARNING:: Run :py:meth:`~ecoli.experiments.ecoli_master_sim.EcoliSim.build_ecoli` @@ -890,29 +905,9 @@ def run(self): # run the experiment if self.save: - self.save_states(self.daughter_outdir) + self.save_states() else: - try: - self.ecoli_experiment.update(self.max_duration) - except DivisionDetected: - state = self.ecoli_experiment.state.get_value(condition=not_a_process) - assert len(state["agents"]) == 2 - for i, agent_state in enumerate(state["agents"].values()): - prepare_save_state(agent_state) - daughter_path = os.path.join( - self.daughter_outdir, f"daughter_state_{i}.json" - ) - write_json(daughter_path, agent_state) - print( - f"Divided at t = {self.ecoli_experiment.global_time} after" - f"{self.ecoli_experiment.global_time - self.initial_global_time} sec." - ) - with open("division_time.sh", "w") as f: - f.write(f"export division_time={self.ecoli_experiment.global_time}") - # Tell Parquet emitter that simulation was successful - if isinstance(self.ecoli_experiment.emitter, ParquetEmitter): - self.ecoli_experiment.emitter.success = True - sys.exit() + self.update_experiment(self.max_duration) self.ecoli_experiment.end() if self.profile: report_profiling(self.ecoli_experiment.stats) diff --git a/ecoli/library/parameters.py b/ecoli/library/parameters.py index f13629abe..c3121eaf3 100644 --- a/ecoli/library/parameters.py +++ b/ecoli/library/parameters.py @@ -685,8 +685,9 @@ def main(): value_str = "{:.2e}".format(row.param.value.to(row.units).magnitude) if "e" in value_str: base, exponent = value_str.split("e") - exponent = exponent.strip("+-0") - base = base.strip("0") + exponent = int(exponent) + if "." in base: + base = base.rstrip("0").rstrip(".") if exponent: value_str = "%s \\times 10^{%s}" % (base, exponent) else: diff --git a/ecoli/library/parquet_emitter.py b/ecoli/library/parquet_emitter.py index 2257f2045..903fd3132 100644 --- a/ecoli/library/parquet_emitter.py +++ b/ecoli/library/parquet_emitter.py @@ -1,4 +1,3 @@ -import atexit import os from concurrent.futures import Future, ThreadPoolExecutor from typing import Any, Callable, cast, Mapping, Optional @@ -810,7 +809,12 @@ def submit(self, fn: Callable, *args, **kwargs) -> Future: class ParquetEmitter(Emitter): """ - Emit data to a Parquet dataset. + Emit data to a Parquet dataset. Note that :py:meth:`~.finalize` + must be explicitly called in a ``try...finally`` block around the call to + :py:meth:`vivarium.core.engine.Engine.update` to ensure that all buffered + emits are written to Parquet files when the simulation ends for any reason. + This is handled automatically in :py:class:`~ecoli.experiments.ecoli_master_sim.EcoliSim` + and :py:class:`~ecoli.processes.engine_process.EngineProcess` """ def __init__(self, config: dict[str, Any]) -> None: @@ -858,9 +862,8 @@ def __init__(self, config: dict[str, Any]) -> None: self.last_batch_future.set_result(None) # Set either by EcoliSim or by EngineProcess if sim reaches division self.success = False - atexit.register(self._finalize) - def _finalize(self): + def finalize(self): """Convert remaining batched emits to Parquet at sim shutdown and mark sim as successful if ``success`` flag was set. In vEcoli, this is done by :py:class:`~ecoli.experiments.ecoli_master_sim.EcoliSim` diff --git a/ecoli/library/test_parquet_emitter.py b/ecoli/library/test_parquet_emitter.py index ad0debec5..670306748 100644 --- a/ecoli/library/test_parquet_emitter.py +++ b/ecoli/library/test_parquet_emitter.py @@ -1,4 +1,3 @@ -import atexit import os import re import tempfile @@ -292,10 +291,6 @@ def test_initialization(self, temp_dir): emitter.partitioning_path = "path/to/output" assert emitter.out_uri == "gs://bucket/path" assert emitter.batch_size == 100 - # GCSFS uses asyncio and cannot schedule futures after interpreter shutdown - # so _finalize hook with raise an error that is ignored. Here we just - # unregister the hook to avoid cluttering the pytest log - atexit.unregister(emitter._finalize) def test_emit_configuration(self, temp_dir): """Test emitting configuration data.""" @@ -707,7 +702,7 @@ def test_extreme_data_types(self, temp_dir): ) def test_finalize(self, temp_dir): - """Test _finalize method that handles remaining data.""" + """Test finalize method that handles remaining data.""" emitter = ParquetEmitter({"out_dir": temp_dir}) emitter.experiment_id = "test_exp" emitter.partitioning_path = "path/to/output" @@ -729,8 +724,8 @@ def test_finalize(self, temp_dir): with patch( "ecoli.library.parquet_emitter.json_to_parquet" ) as mock_json_to_parquet: - # Test _finalize - emitter._finalize() + # Test finalize + emitter.finalize() # Verify json_to_parquet was called with truncated data mock_json_to_parquet.assert_called_once() @@ -741,7 +736,7 @@ def test_finalize(self, temp_dir): # Test success flag emitter.success = True - emitter._finalize() + emitter.finalize() assert os.path.exists( os.path.join( emitter.out_uri, @@ -939,8 +934,7 @@ def delayed_execution(): # Changed type for field2 to list so should fail with pytest.raises(pl.exceptions.InvalidOperationError): - emitter._finalize() - atexit.unregister(emitter._finalize) + emitter.finalize() # Cleanup the real executor real_executor.shutdown() diff --git a/ecoli/processes/engine_process.py b/ecoli/processes/engine_process.py index df45cbe21..c7bea28da 100644 --- a/ecoli/processes/engine_process.py +++ b/ecoli/processes/engine_process.py @@ -505,9 +505,14 @@ def next_update(self, timestep, states): self.emitter.emit(emit_config) # Run inner simulation for timestep. - self.sim.run_for(timestep) - if force_complete: - self.sim.complete() + try: + self.sim.run_for(timestep) + if force_complete: + self.sim.complete() + except (Exception, KeyboardInterrupt): + if isinstance(self.emitter, ParquetEmitter): + self.emitter.finalize() + raise update = {} @@ -520,7 +525,7 @@ def next_update(self, timestep, states): # Finalize emits before division if isinstance(self.emitter, ParquetEmitter): self.emitter.success = True - self.emitter._finalize() + self.emitter.finalize() # Perform division. daughters = [] daughter_states = self.sim.state.divide_value() diff --git a/pyproject.toml b/pyproject.toml index 4e094894b..f49e89e6f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,10 @@ dependencies = [ "stochastic-arrow", "autograd", "xmltodict", + # Required to save Altair charts as PNG + "vl-convert-python", + "plotly", + "ipykernel" ] [project.optional-dependencies] diff --git a/runscripts/analysis.py b/runscripts/analysis.py index f1f53ceb6..ae03b123b 100644 --- a/runscripts/analysis.py +++ b/runscripts/analysis.py @@ -292,6 +292,17 @@ def main(): } variant_names = {config["experiment_id"][0]: variant_name} + # Save copy of config JSON with parameters for plots + metadata_path = os.path.join(os.path.abspath(config["outdir"]), "metadata.json") + if os.path.exists(metadata_path): + raise FileExistsError( + f"{metadata_path} already exists, indicating an analysis has " + f"been run with output directory {config['outdir']}. Please " + "delete/move it or specify a different output directory." + ) + with open(metadata_path, "w") as f: + json.dump(config, f) + # Establish DuckDB connection conn = create_duckdb_conn(out_uri, gcs_bucket, config.get("cpus")) history_sql, config_sql, success_sql = dataset_sql(out_uri, config["experiment_id"]) @@ -339,6 +350,8 @@ def main(): curr_outdir, ) else: + curr_outdir = os.path.abspath(config["outdir"]) + os.makedirs(curr_outdir, exist_ok=True) query_strings[duckdb_filter] = ( f"SELECT * FROM ({history_sql}) WHERE {duckdb_filter}", f"SELECT * FROM ({config_sql}) WHERE {duckdb_filter}", @@ -369,12 +382,6 @@ def main(): variant_names, ) - # Save copy of config JSON with parameters for plots - with open( - os.path.join(os.path.abspath(config["outdir"]), "metadata.json"), "w" - ) as f: - json.dump(config, f) - if __name__ == "__main__": main() diff --git a/runscripts/container/Dockerfile b/runscripts/container/Dockerfile index ff326c6c9..d18b93150 100644 --- a/runscripts/container/Dockerfile +++ b/runscripts/container/Dockerfile @@ -14,8 +14,11 @@ RUN echo "alias ls='ls --color=auto'" >> ~/.bashrc \ && echo "alias ll='ls -l'" >> ~/.bashrc \ && cp ~/.bashrc / +# gcc necessary for compiling C extensions in some Python packages. # procps necessary for `ps` command used by Nextflow to track processes. -RUN apt-get update && apt-get install -y gcc procps nano +# nano is a text editor for convenience. +# curl is necessary for authentication on Google Cloud VMs +RUN apt-get update && apt-get install -y gcc procps nano curl # Install the project into `/vEcoli` WORKDIR /vEcoli diff --git a/runscripts/container/Singularity b/runscripts/container/Singularity index 201536342..6df713271 100644 --- a/runscripts/container/Singularity +++ b/runscripts/container/Singularity @@ -27,6 +27,6 @@ From: ghcr.io/astral-sh/uv@sha256:1cc0392c8aad8026ef3922e3f997fff0f31e506b0ffe95 FILES_TO_ADD %post - apt-get update && apt-get install -y gcc procps nano + apt-get update && apt-get install -y gcc procps nano curl cd /vEcoli UV_CACHE_DIR="/vEcoli/.uv_cache" UV_COMPILE_BYTECODE=1 uv sync --frozen diff --git a/runscripts/debug/process_vulnerabilities.py b/runscripts/debug/process_vulnerabilities.py new file mode 100644 index 000000000..a98b8f453 --- /dev/null +++ b/runscripts/debug/process_vulnerabilities.py @@ -0,0 +1,185 @@ +""" +Process vulnerability data from comma-separated JSON format. + +This script processes JSON data containing package vulnerability information, +generates a markdown report with vulnerability details, and creates a shell +script to apply package upgrades using uv. + +Expected JSON format: +{ + "name": "package_name", + "version": "current_version", + "vulns": [ + { + "id": "VULNERABILITY_ID", + "fix_versions": ["fixed_version"], + "aliases": ["ALIAS1", "ALIAS2"], + "description": "Vulnerability description" + } + ] +} +""" + +import os +import json +import sys +from typing import Any +from datetime import datetime +import argparse +from packaging.version import Version + + +def generate_markdown_report(packages: list[dict[str, Any]]) -> tuple[str, list[str]]: + """Generate a markdown report of vulnerabilities and upgrades.""" + + markdown = f"""# Security Vulnerability Report + +Generated on: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")} + +## Summary + +Found vulnerabilities in **{len(packages)}** packages requiring updates. + +## Package Upgrades Overview + +| Package | Current Version | Recommended Version | Vulnerabilities | +|---------|----------------|-------------------|-----------------| +""" + + # Package summary table + upgrade_commands = [] + + for pkg in packages: + name = pkg.get("name", "Unknown") + current_version = pkg.get("version", "Unknown") + vulns = pkg.get("vulns", []) + + # Find the highest fix version across all vulnerabilities + all_fix_versions = [] + vuln_count = len(vulns) + + for vuln in vulns: + fix_versions = vuln.get("fix_versions", []) + all_fix_versions.extend([Version(v) for v in fix_versions if v]) + + recommended_version = max(all_fix_versions) if all_fix_versions else "Unknown" + + markdown += f"| **{name}** | {current_version} | **{recommended_version}** | {vuln_count} |\n" + + if recommended_version != "Unknown": + upgrade_commands.append(f'-P "{name}=={recommended_version}"') + + markdown += "\n## Detailed Vulnerability Information\n\n" + + # Detailed vulnerability information + for pkg in packages: + name = pkg.get("name", "Unknown") + current_version = pkg.get("version", "Unknown") + vulns = pkg.get("vulns", []) + + markdown += f"### {name} (v{current_version})\n\n" + + if not vulns: + markdown += "No specific vulnerability details available.\n\n" + continue + + markdown += "| Vulnerability ID | Fix Versions | Aliases |\n" + markdown += "|-----------------|-------------|---------|\n" + + for vuln in vulns: + vuln_id = vuln.get("id", "Unknown") + fix_versions = ", ".join(vuln.get("fix_versions", ["Unknown"])) + aliases = ", ".join(vuln.get("aliases", [])) + + markdown += f"| {vuln_id} | {fix_versions} | {aliases} |\n" + + markdown += "\n" + + markdown += """ +## Recommended Actions + +1. Review the vulnerability details above. +2. Close and reopen this PR to trigger CI/CD tests. +3. Approve and merge the PR if everything looks good. + +--- +*This report was generated automatically. Please verify all upgrades before applying.* +""" + + return markdown, upgrade_commands + + +def main(): + parser = argparse.ArgumentParser( + description="Process vulnerability data and generate reports" + ) + parser.add_argument( + "input_file", + nargs="?", + help="Input file with comma-separated JSONs (default: stdin)", + ) + parser.add_argument( + "--output-md", default="vulnerability_report.md", help="Output markdown file" + ) + parser.add_argument( + "--output-sh", + default="apply_security_upgrades.sh", + help="Output shell script file", + ) + + args = parser.parse_args() + + # Read input data + if args.input_file: + try: + with open(args.input_file, "r") as f: + input_data = json.load(f) + except FileNotFoundError: + print(f"Error: File '{args.input_file}' not found.", file=sys.stderr) + sys.exit(1) + else: + print("Reading from stdin... (Ctrl+D to end)") + input_data = json.load(sys.stdin) + + if not input_data: + print("Error: No input data provided.", file=sys.stderr) + sys.exit(1) + + # Process the data + packages = [pkg for pkg in input_data["dependencies"] if pkg["vulns"]] + + print(f"📋 Detected {len(packages)} vulnerable packages") + + # Generate markdown report + markdown_content, upgrade_commands = generate_markdown_report(packages) + with open(args.output_md, "w") as f: + f.write(markdown_content) + print(f"📄 Markdown report saved to: {args.output_md}") + + # Generate shell script + script = f"""#!/bin/bash +# Security upgrade script +# Generated automatically from vulnerability analysis + +set -e # Exit on any error + +echo "🔒 Applying security upgrades..." +echo "This script will upgrade vulnerable packages using uv lock --upgrade-package" +uv lock {" ".join(upgrade_commands)} + +echo "✅ All security upgrades completed successfully!" +""" + with open(args.output_sh, "w") as f: + f.write(script) + + # Make script executable + os.chmod(args.output_sh, 0o755) + print(f"🔧 Shell script saved to: {args.output_sh} (executable)") + + print("\n✅ Processing complete!") + print(f"Review the report: {args.output_md}") + print(f"Apply upgrades: ./{args.output_sh}") + + +if __name__ == "__main__": + main() diff --git a/runscripts/jenkins/configs/ecoli-glucose-minimal.json b/runscripts/jenkins/configs/ecoli-glucose-minimal.json index a90b7983d..747022a3a 100644 --- a/runscripts/jenkins/configs/ecoli-glucose-minimal.json +++ b/runscripts/jenkins/configs/ecoli-glucose-minimal.json @@ -12,7 +12,8 @@ "analysis_options": { "single": {"mass_fraction_summary": {}}, "multiseed": {"protein_counts_validation": {}}, - "multivariant": {"doubling_time_hist": {"skip_n_gens": 0}, "doubling_time_line": {}} + "multivariant": {"doubling_time_hist": {"skip_n_gens": 0}, "doubling_time_line": {}, "cell_mass": {}}, + "multigeneration": {"replication": {}, "ribosome_usage": {}} }, "sherlock": { "container_image": "container-image", diff --git a/runscripts/sim.py b/runscripts/sim.py index 3f602cdde..85c2b45ac 100644 --- a/runscripts/sim.py +++ b/runscripts/sim.py @@ -1,4 +1,5 @@ import os +import signal import sys import subprocess @@ -16,8 +17,14 @@ def main(): # Forward all arguments cmd = [sys.executable, script_path] + sys.argv[1:] # Execute and forward exit code - result = subprocess.run(cmd) - return result.returncode + proc = subprocess.Popen(cmd) + try: + proc.wait() + # Ensure emits are finalized even if wrapper is interrupted + finally: + proc.send_signal(signal.SIGINT) + proc.wait() + return proc.returncode if __name__ == "__main__": diff --git a/runscripts/workflow.py b/runscripts/workflow.py index 201304cfa..c906b57df 100644 --- a/runscripts/workflow.py +++ b/runscripts/workflow.py @@ -409,7 +409,7 @@ def main(): else: out_uri = config["emitter_arg"]["out_uri"] parsed_uri = parse.urlparse(out_uri) - if parsed_uri.schema not in ("local", "file") and not FSSPEC_AVAILABLE: + if parsed_uri.scheme not in ("local", "file") and not FSSPEC_AVAILABLE: raise RuntimeError( f"URI '{out_uri}' specified but fsspec is not available. " "Install fsspec or provide a local URI/out directory." @@ -423,14 +423,27 @@ def main(): config["lineage_seed"] = random.randint(0, 2**31 - 1) filesystem, outdir = parse_uri(out_uri) outdir = os.path.join(outdir, experiment_id, "nextflow") + exp_outdir = os.path.dirname(outdir) out_uri = os.path.join(out_uri, experiment_id, "nextflow") repo_dir = os.path.dirname(os.path.dirname(__file__)) local_outdir = os.path.join(repo_dir, "nextflow_temp", experiment_id) os.makedirs(local_outdir, exist_ok=True) if filesystem is None: - os.makedirs(outdir, exist_ok=args.resume) + if os.path.exists(exp_outdir) and not args.resume: + raise RuntimeError( + f"Output directory already exists: {exp_outdir}. " + "Please use a different experiment ID or output directory. " + "Alternatively, move, delete, or rename the existing directory." + ) + os.makedirs(outdir, exist_ok=True) else: - filesystem.makedirs(outdir, exist_ok=args.resume) + if filesystem.exists(exp_outdir) and not args.resume: + raise RuntimeError( + f"Output directory already exists: {exp_outdir}. " + "Please use a different experiment ID or output directory. " + "Alternatively, move, delete, or rename the existing directory." + ) + filesystem.makedirs(outdir, exist_ok=True) temp_config_path = f"{local_outdir}/workflow_config.json" final_config_path = os.path.join(outdir, "workflow_config.json") final_config_uri = os.path.join(out_uri, "workflow_config.json") diff --git a/uv.lock b/uv.lock index 18bbc4489..fa0e82969 100644 --- a/uv.lock +++ b/uv.lock @@ -26,7 +26,7 @@ wheels = [ [[package]] name = "aiohttp" -version = "3.12.13" +version = "3.12.14" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "aiohappyeyeballs" }, @@ -37,37 +37,38 @@ dependencies = [ { name = "propcache" }, { name = "yarl" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/42/6e/ab88e7cb2a4058bed2f7870276454f85a7c56cd6da79349eb314fc7bbcaa/aiohttp-3.12.13.tar.gz", hash = "sha256:47e2da578528264a12e4e3dd8dd72a7289e5f812758fe086473fab037a10fcce", size = 7819160, upload-time = "2025-06-14T15:15:41.354Z" } +sdist = { url = "https://files.pythonhosted.org/packages/e6/0b/e39ad954107ebf213a2325038a3e7a506be3d98e1435e1f82086eec4cde2/aiohttp-3.12.14.tar.gz", hash = "sha256:6e06e120e34d93100de448fd941522e11dafa78ef1a893c179901b7d66aa29f2", size = 7822921, upload-time = "2025-07-10T13:05:33.968Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/b4/6a/ce40e329788013cd190b1d62bbabb2b6a9673ecb6d836298635b939562ef/aiohttp-3.12.13-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:0aa580cf80558557285b49452151b9c69f2fa3ad94c5c9e76e684719a8791b73", size = 700491, upload-time = "2025-06-14T15:14:00.048Z" }, - { url = "https://files.pythonhosted.org/packages/28/d9/7150d5cf9163e05081f1c5c64a0cdf3c32d2f56e2ac95db2a28fe90eca69/aiohttp-3.12.13-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:b103a7e414b57e6939cc4dece8e282cfb22043efd0c7298044f6594cf83ab347", size = 475104, upload-time = "2025-06-14T15:14:01.691Z" }, - { url = "https://files.pythonhosted.org/packages/f8/91/d42ba4aed039ce6e449b3e2db694328756c152a79804e64e3da5bc19dffc/aiohttp-3.12.13-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:78f64e748e9e741d2eccff9597d09fb3cd962210e5b5716047cbb646dc8fe06f", size = 467948, upload-time = "2025-06-14T15:14:03.561Z" }, - { url = "https://files.pythonhosted.org/packages/99/3b/06f0a632775946981d7c4e5a865cddb6e8dfdbaed2f56f9ade7bb4a1039b/aiohttp-3.12.13-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:29c955989bf4c696d2ededc6b0ccb85a73623ae6e112439398935362bacfaaf6", size = 1714742, upload-time = "2025-06-14T15:14:05.558Z" }, - { url = "https://files.pythonhosted.org/packages/92/a6/2552eebad9ec5e3581a89256276009e6a974dc0793632796af144df8b740/aiohttp-3.12.13-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:d640191016763fab76072c87d8854a19e8e65d7a6fcfcbf017926bdbbb30a7e5", size = 1697393, upload-time = "2025-06-14T15:14:07.194Z" }, - { url = "https://files.pythonhosted.org/packages/d8/9f/bd08fdde114b3fec7a021381b537b21920cdd2aa29ad48c5dffd8ee314f1/aiohttp-3.12.13-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4dc507481266b410dede95dd9f26c8d6f5a14315372cc48a6e43eac652237d9b", size = 1752486, upload-time = "2025-06-14T15:14:08.808Z" }, - { url = "https://files.pythonhosted.org/packages/f7/e1/affdea8723aec5bd0959171b5490dccd9a91fcc505c8c26c9f1dca73474d/aiohttp-3.12.13-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8a94daa873465d518db073bd95d75f14302e0208a08e8c942b2f3f1c07288a75", size = 1798643, upload-time = "2025-06-14T15:14:10.767Z" }, - { url = "https://files.pythonhosted.org/packages/f3/9d/666d856cc3af3a62ae86393baa3074cc1d591a47d89dc3bf16f6eb2c8d32/aiohttp-3.12.13-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:177f52420cde4ce0bb9425a375d95577fe082cb5721ecb61da3049b55189e4e6", size = 1718082, upload-time = "2025-06-14T15:14:12.38Z" }, - { url = "https://files.pythonhosted.org/packages/f3/ce/3c185293843d17be063dada45efd2712bb6bf6370b37104b4eda908ffdbd/aiohttp-3.12.13-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0f7df1f620ec40f1a7fbcb99ea17d7326ea6996715e78f71a1c9a021e31b96b8", size = 1633884, upload-time = "2025-06-14T15:14:14.415Z" }, - { url = "https://files.pythonhosted.org/packages/3a/5b/f3413f4b238113be35dfd6794e65029250d4b93caa0974ca572217745bdb/aiohttp-3.12.13-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:3062d4ad53b36e17796dce1c0d6da0ad27a015c321e663657ba1cc7659cfc710", size = 1694943, upload-time = "2025-06-14T15:14:16.48Z" }, - { url = "https://files.pythonhosted.org/packages/82/c8/0e56e8bf12081faca85d14a6929ad5c1263c146149cd66caa7bc12255b6d/aiohttp-3.12.13-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:8605e22d2a86b8e51ffb5253d9045ea73683d92d47c0b1438e11a359bdb94462", size = 1716398, upload-time = "2025-06-14T15:14:18.589Z" }, - { url = "https://files.pythonhosted.org/packages/ea/f3/33192b4761f7f9b2f7f4281365d925d663629cfaea093a64b658b94fc8e1/aiohttp-3.12.13-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:54fbbe6beafc2820de71ece2198458a711e224e116efefa01b7969f3e2b3ddae", size = 1657051, upload-time = "2025-06-14T15:14:20.223Z" }, - { url = "https://files.pythonhosted.org/packages/5e/0b/26ddd91ca8f84c48452431cb4c5dd9523b13bc0c9766bda468e072ac9e29/aiohttp-3.12.13-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:050bd277dfc3768b606fd4eae79dd58ceda67d8b0b3c565656a89ae34525d15e", size = 1736611, upload-time = "2025-06-14T15:14:21.988Z" }, - { url = "https://files.pythonhosted.org/packages/c3/8d/e04569aae853302648e2c138a680a6a2f02e374c5b6711732b29f1e129cc/aiohttp-3.12.13-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:2637a60910b58f50f22379b6797466c3aa6ae28a6ab6404e09175ce4955b4e6a", size = 1764586, upload-time = "2025-06-14T15:14:23.979Z" }, - { url = "https://files.pythonhosted.org/packages/ac/98/c193c1d1198571d988454e4ed75adc21c55af247a9fda08236602921c8c8/aiohttp-3.12.13-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:e986067357550d1aaa21cfe9897fa19e680110551518a5a7cf44e6c5638cb8b5", size = 1724197, upload-time = "2025-06-14T15:14:25.692Z" }, - { url = "https://files.pythonhosted.org/packages/e7/9e/07bb8aa11eec762c6b1ff61575eeeb2657df11ab3d3abfa528d95f3e9337/aiohttp-3.12.13-cp312-cp312-win32.whl", hash = "sha256:ac941a80aeea2aaae2875c9500861a3ba356f9ff17b9cb2dbfb5cbf91baaf5bf", size = 421771, upload-time = "2025-06-14T15:14:27.364Z" }, - { url = "https://files.pythonhosted.org/packages/52/66/3ce877e56ec0813069cdc9607cd979575859c597b6fb9b4182c6d5f31886/aiohttp-3.12.13-cp312-cp312-win_amd64.whl", hash = "sha256:671f41e6146a749b6c81cb7fd07f5a8356d46febdaaaf07b0e774ff04830461e", size = 447869, upload-time = "2025-06-14T15:14:29.05Z" }, + { url = "https://files.pythonhosted.org/packages/c3/0d/29026524e9336e33d9767a1e593ae2b24c2b8b09af7c2bd8193762f76b3e/aiohttp-3.12.14-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:a0ecbb32fc3e69bc25efcda7d28d38e987d007096cbbeed04f14a6662d0eee22", size = 701055, upload-time = "2025-07-10T13:03:45.59Z" }, + { url = "https://files.pythonhosted.org/packages/0a/b8/a5e8e583e6c8c1056f4b012b50a03c77a669c2e9bf012b7cf33d6bc4b141/aiohttp-3.12.14-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:0400f0ca9bb3e0b02f6466421f253797f6384e9845820c8b05e976398ac1d81a", size = 475670, upload-time = "2025-07-10T13:03:47.249Z" }, + { url = "https://files.pythonhosted.org/packages/29/e8/5202890c9e81a4ec2c2808dd90ffe024952e72c061729e1d49917677952f/aiohttp-3.12.14-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a56809fed4c8a830b5cae18454b7464e1529dbf66f71c4772e3cfa9cbec0a1ff", size = 468513, upload-time = "2025-07-10T13:03:49.377Z" }, + { url = "https://files.pythonhosted.org/packages/23/e5/d11db8c23d8923d3484a27468a40737d50f05b05eebbb6288bafcb467356/aiohttp-3.12.14-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:27f2e373276e4755691a963e5d11756d093e346119f0627c2d6518208483fb6d", size = 1715309, upload-time = "2025-07-10T13:03:51.556Z" }, + { url = "https://files.pythonhosted.org/packages/53/44/af6879ca0eff7a16b1b650b7ea4a827301737a350a464239e58aa7c387ef/aiohttp-3.12.14-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:ca39e433630e9a16281125ef57ece6817afd1d54c9f1bf32e901f38f16035869", size = 1697961, upload-time = "2025-07-10T13:03:53.511Z" }, + { url = "https://files.pythonhosted.org/packages/bb/94/18457f043399e1ec0e59ad8674c0372f925363059c276a45a1459e17f423/aiohttp-3.12.14-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9c748b3f8b14c77720132b2510a7d9907a03c20ba80f469e58d5dfd90c079a1c", size = 1753055, upload-time = "2025-07-10T13:03:55.368Z" }, + { url = "https://files.pythonhosted.org/packages/26/d9/1d3744dc588fafb50ff8a6226d58f484a2242b5dd93d8038882f55474d41/aiohttp-3.12.14-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f0a568abe1b15ce69d4cc37e23020720423f0728e3cb1f9bcd3f53420ec3bfe7", size = 1799211, upload-time = "2025-07-10T13:03:57.216Z" }, + { url = "https://files.pythonhosted.org/packages/73/12/2530fb2b08773f717ab2d249ca7a982ac66e32187c62d49e2c86c9bba9b4/aiohttp-3.12.14-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9888e60c2c54eaf56704b17feb558c7ed6b7439bca1e07d4818ab878f2083660", size = 1718649, upload-time = "2025-07-10T13:03:59.469Z" }, + { url = "https://files.pythonhosted.org/packages/b9/34/8d6015a729f6571341a311061b578e8b8072ea3656b3d72329fa0faa2c7c/aiohttp-3.12.14-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3006a1dc579b9156de01e7916d38c63dc1ea0679b14627a37edf6151bc530088", size = 1634452, upload-time = "2025-07-10T13:04:01.698Z" }, + { url = "https://files.pythonhosted.org/packages/ff/4b/08b83ea02595a582447aeb0c1986792d0de35fe7a22fb2125d65091cbaf3/aiohttp-3.12.14-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:aa8ec5c15ab80e5501a26719eb48a55f3c567da45c6ea5bb78c52c036b2655c7", size = 1695511, upload-time = "2025-07-10T13:04:04.165Z" }, + { url = "https://files.pythonhosted.org/packages/b5/66/9c7c31037a063eec13ecf1976185c65d1394ded4a5120dd5965e3473cb21/aiohttp-3.12.14-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:39b94e50959aa07844c7fe2206b9f75d63cc3ad1c648aaa755aa257f6f2498a9", size = 1716967, upload-time = "2025-07-10T13:04:06.132Z" }, + { url = "https://files.pythonhosted.org/packages/ba/02/84406e0ad1acb0fb61fd617651ab6de760b2d6a31700904bc0b33bd0894d/aiohttp-3.12.14-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:04c11907492f416dad9885d503fbfc5dcb6768d90cad8639a771922d584609d3", size = 1657620, upload-time = "2025-07-10T13:04:07.944Z" }, + { url = "https://files.pythonhosted.org/packages/07/53/da018f4013a7a179017b9a274b46b9a12cbeb387570f116964f498a6f211/aiohttp-3.12.14-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:88167bd9ab69bb46cee91bd9761db6dfd45b6e76a0438c7e884c3f8160ff21eb", size = 1737179, upload-time = "2025-07-10T13:04:10.182Z" }, + { url = "https://files.pythonhosted.org/packages/49/e8/ca01c5ccfeaafb026d85fa4f43ceb23eb80ea9c1385688db0ef322c751e9/aiohttp-3.12.14-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:791504763f25e8f9f251e4688195e8b455f8820274320204f7eafc467e609425", size = 1765156, upload-time = "2025-07-10T13:04:12.029Z" }, + { url = "https://files.pythonhosted.org/packages/22/32/5501ab525a47ba23c20613e568174d6c63aa09e2caa22cded5c6ea8e3ada/aiohttp-3.12.14-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:2785b112346e435dd3a1a67f67713a3fe692d288542f1347ad255683f066d8e0", size = 1724766, upload-time = "2025-07-10T13:04:13.961Z" }, + { url = "https://files.pythonhosted.org/packages/06/af/28e24574801fcf1657945347ee10df3892311c2829b41232be6089e461e7/aiohttp-3.12.14-cp312-cp312-win32.whl", hash = "sha256:15f5f4792c9c999a31d8decf444e79fcfd98497bf98e94284bf390a7bb8c1729", size = 422641, upload-time = "2025-07-10T13:04:16.018Z" }, + { url = "https://files.pythonhosted.org/packages/98/d5/7ac2464aebd2eecac38dbe96148c9eb487679c512449ba5215d233755582/aiohttp-3.12.14-cp312-cp312-win_amd64.whl", hash = "sha256:3b66e1a182879f579b105a80d5c4bd448b91a57e8933564bf41665064796a338", size = 449316, upload-time = "2025-07-10T13:04:18.289Z" }, ] [[package]] name = "aiosignal" -version = "1.3.2" +version = "1.4.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "frozenlist" }, + { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ba/b5/6d55e80f6d8a08ce22b982eafa278d823b541c925f11ee774b0b9c43473d/aiosignal-1.3.2.tar.gz", hash = "sha256:a8c255c66fafb1e499c9351d0bf32ff2d8a0321595ebac3b93713656d2436f54", size = 19424, upload-time = "2024-12-13T17:10:40.86Z" } +sdist = { url = "https://files.pythonhosted.org/packages/61/62/06741b579156360248d1ec624842ad0edf697050bbaf7c3e46394e106ad1/aiosignal-1.4.0.tar.gz", hash = "sha256:f47eecd9468083c2029cc99945502cb7708b082c232f9aca65da147157b251c7", size = 25007, upload-time = "2025-07-03T22:54:43.528Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ec/6a/bc7e17a3e87a2985d3e8f4da4cd0f481060eb78fb08596c42be62c90a4d9/aiosignal-1.3.2-py2.py3-none-any.whl", hash = "sha256:45cde58e409a301715980c2b01d0c28bdde3770d8290b5eb2173759d9acb31a5", size = 7597, upload-time = "2024-12-13T17:10:38.469Z" }, + { url = "https://files.pythonhosted.org/packages/fb/76/641ae371508676492379f16e2fa48f4e2c11741bd63c48be4b12a6b09cba/aiosignal-1.4.0-py3-none-any.whl", hash = "sha256:053243f8b92b990551949e63930a839ff0cf0b0ebbe0597b0f3fb19e1a0fe82e", size = 7490, upload-time = "2025-07-03T22:54:42.156Z" }, ] [[package]] @@ -1973,21 +1974,21 @@ wheels = [ [[package]] name = "pillow" -version = "11.2.1" +version = "11.3.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/af/cb/bb5c01fcd2a69335b86c22142b2bccfc3464087efb7fd382eee5ffc7fdf7/pillow-11.2.1.tar.gz", hash = "sha256:a64dd61998416367b7ef979b73d3a85853ba9bec4c2925f74e588879a58716b6", size = 47026707, upload-time = "2025-04-12T17:50:03.289Z" } +sdist = { url = "https://files.pythonhosted.org/packages/f3/0d/d0d6dea55cd152ce3d6767bb38a8fc10e33796ba4ba210cbab9354b6d238/pillow-11.3.0.tar.gz", hash = "sha256:3828ee7586cd0b2091b6209e5ad53e20d0649bbe87164a459d0676e035e8f523", size = 47113069, upload-time = "2025-07-01T09:16:30.666Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c7/40/052610b15a1b8961f52537cc8326ca6a881408bc2bdad0d852edeb6ed33b/pillow-11.2.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:78afba22027b4accef10dbd5eed84425930ba41b3ea0a86fa8d20baaf19d807f", size = 3190185, upload-time = "2025-04-12T17:48:00.417Z" }, - { url = "https://files.pythonhosted.org/packages/e5/7e/b86dbd35a5f938632093dc40d1682874c33dcfe832558fc80ca56bfcb774/pillow-11.2.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:78092232a4ab376a35d68c4e6d5e00dfd73454bd12b230420025fbe178ee3b0b", size = 3030306, upload-time = "2025-04-12T17:48:02.391Z" }, - { url = "https://files.pythonhosted.org/packages/a4/5c/467a161f9ed53e5eab51a42923c33051bf8d1a2af4626ac04f5166e58e0c/pillow-11.2.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:25a5f306095c6780c52e6bbb6109624b95c5b18e40aab1c3041da3e9e0cd3e2d", size = 4416121, upload-time = "2025-04-12T17:48:04.554Z" }, - { url = "https://files.pythonhosted.org/packages/62/73/972b7742e38ae0e2ac76ab137ca6005dcf877480da0d9d61d93b613065b4/pillow-11.2.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0c7b29dbd4281923a2bfe562acb734cee96bbb129e96e6972d315ed9f232bef4", size = 4501707, upload-time = "2025-04-12T17:48:06.831Z" }, - { url = "https://files.pythonhosted.org/packages/e4/3a/427e4cb0b9e177efbc1a84798ed20498c4f233abde003c06d2650a6d60cb/pillow-11.2.1-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:3e645b020f3209a0181a418bffe7b4a93171eef6c4ef6cc20980b30bebf17b7d", size = 4522921, upload-time = "2025-04-12T17:48:09.229Z" }, - { url = "https://files.pythonhosted.org/packages/fe/7c/d8b1330458e4d2f3f45d9508796d7caf0c0d3764c00c823d10f6f1a3b76d/pillow-11.2.1-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:b2dbea1012ccb784a65349f57bbc93730b96e85b42e9bf7b01ef40443db720b4", size = 4612523, upload-time = "2025-04-12T17:48:11.631Z" }, - { url = "https://files.pythonhosted.org/packages/b3/2f/65738384e0b1acf451de5a573d8153fe84103772d139e1e0bdf1596be2ea/pillow-11.2.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:da3104c57bbd72948d75f6a9389e6727d2ab6333c3617f0a89d72d4940aa0443", size = 4587836, upload-time = "2025-04-12T17:48:13.592Z" }, - { url = "https://files.pythonhosted.org/packages/6a/c5/e795c9f2ddf3debb2dedd0df889f2fe4b053308bb59a3cc02a0cd144d641/pillow-11.2.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:598174aef4589af795f66f9caab87ba4ff860ce08cd5bb447c6fc553ffee603c", size = 4669390, upload-time = "2025-04-12T17:48:15.938Z" }, - { url = "https://files.pythonhosted.org/packages/96/ae/ca0099a3995976a9fce2f423166f7bff9b12244afdc7520f6ed38911539a/pillow-11.2.1-cp312-cp312-win32.whl", hash = "sha256:1d535df14716e7f8776b9e7fee118576d65572b4aad3ed639be9e4fa88a1cad3", size = 2332309, upload-time = "2025-04-12T17:48:17.885Z" }, - { url = "https://files.pythonhosted.org/packages/7c/18/24bff2ad716257fc03da964c5e8f05d9790a779a8895d6566e493ccf0189/pillow-11.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:14e33b28bf17c7a38eede290f77db7c664e4eb01f7869e37fa98a5aa95978941", size = 2676768, upload-time = "2025-04-12T17:48:19.655Z" }, - { url = "https://files.pythonhosted.org/packages/da/bb/e8d656c9543276517ee40184aaa39dcb41e683bca121022f9323ae11b39d/pillow-11.2.1-cp312-cp312-win_arm64.whl", hash = "sha256:21e1470ac9e5739ff880c211fc3af01e3ae505859392bf65458c224d0bf283eb", size = 2415087, upload-time = "2025-04-12T17:48:21.991Z" }, + { url = "https://files.pythonhosted.org/packages/40/fe/1bc9b3ee13f68487a99ac9529968035cca2f0a51ec36892060edcc51d06a/pillow-11.3.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:fdae223722da47b024b867c1ea0be64e0df702c5e0a60e27daad39bf960dd1e4", size = 5278800, upload-time = "2025-07-01T09:14:17.648Z" }, + { url = "https://files.pythonhosted.org/packages/2c/32/7e2ac19b5713657384cec55f89065fb306b06af008cfd87e572035b27119/pillow-11.3.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:921bd305b10e82b4d1f5e802b6850677f965d8394203d182f078873851dada69", size = 4686296, upload-time = "2025-07-01T09:14:19.828Z" }, + { url = "https://files.pythonhosted.org/packages/8e/1e/b9e12bbe6e4c2220effebc09ea0923a07a6da1e1f1bfbc8d7d29a01ce32b/pillow-11.3.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:eb76541cba2f958032d79d143b98a3a6b3ea87f0959bbe256c0b5e416599fd5d", size = 5871726, upload-time = "2025-07-03T13:10:04.448Z" }, + { url = "https://files.pythonhosted.org/packages/8d/33/e9200d2bd7ba00dc3ddb78df1198a6e80d7669cce6c2bdbeb2530a74ec58/pillow-11.3.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:67172f2944ebba3d4a7b54f2e95c786a3a50c21b88456329314caaa28cda70f6", size = 7644652, upload-time = "2025-07-03T13:10:10.391Z" }, + { url = "https://files.pythonhosted.org/packages/41/f1/6f2427a26fc683e00d985bc391bdd76d8dd4e92fac33d841127eb8fb2313/pillow-11.3.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:97f07ed9f56a3b9b5f49d3661dc9607484e85c67e27f3e8be2c7d28ca032fec7", size = 5977787, upload-time = "2025-07-01T09:14:21.63Z" }, + { url = "https://files.pythonhosted.org/packages/e4/c9/06dd4a38974e24f932ff5f98ea3c546ce3f8c995d3f0985f8e5ba48bba19/pillow-11.3.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:676b2815362456b5b3216b4fd5bd89d362100dc6f4945154ff172e206a22c024", size = 6645236, upload-time = "2025-07-01T09:14:23.321Z" }, + { url = "https://files.pythonhosted.org/packages/40/e7/848f69fb79843b3d91241bad658e9c14f39a32f71a301bcd1d139416d1be/pillow-11.3.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:3e184b2f26ff146363dd07bde8b711833d7b0202e27d13540bfe2e35a323a809", size = 6086950, upload-time = "2025-07-01T09:14:25.237Z" }, + { url = "https://files.pythonhosted.org/packages/0b/1a/7cff92e695a2a29ac1958c2a0fe4c0b2393b60aac13b04a4fe2735cad52d/pillow-11.3.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:6be31e3fc9a621e071bc17bb7de63b85cbe0bfae91bb0363c893cbe67247780d", size = 6723358, upload-time = "2025-07-01T09:14:27.053Z" }, + { url = "https://files.pythonhosted.org/packages/26/7d/73699ad77895f69edff76b0f332acc3d497f22f5d75e5360f78cbcaff248/pillow-11.3.0-cp312-cp312-win32.whl", hash = "sha256:7b161756381f0918e05e7cb8a371fff367e807770f8fe92ecb20d905d0e1c149", size = 6275079, upload-time = "2025-07-01T09:14:30.104Z" }, + { url = "https://files.pythonhosted.org/packages/8c/ce/e7dfc873bdd9828f3b6e5c2bbb74e47a98ec23cc5c74fc4e54462f0d9204/pillow-11.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:a6444696fce635783440b7f7a9fc24b3ad10a9ea3f0ab66c5905be1c19ccf17d", size = 6986324, upload-time = "2025-07-01T09:14:31.899Z" }, + { url = "https://files.pythonhosted.org/packages/16/8f/b13447d1bf0b1f7467ce7d86f6e6edf66c0ad7cf44cf5c87a37f9bed9936/pillow-11.3.0-cp312-cp312-win_arm64.whl", hash = "sha256:2aceea54f957dd4448264f9bf40875da0415c83eb85f55069d89c0ed436e3542", size = 2423067, upload-time = "2025-07-01T09:14:33.709Z" }, ] [[package]] @@ -2014,6 +2015,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fe/39/979e8e21520d4e47a0bbe349e2713c0aac6f3d853d0e5b34d76206c439aa/platformdirs-4.3.8-py3-none-any.whl", hash = "sha256:ff7059bb7eb1179e2685604f4aaf157cfd9535242bd23742eadc3c13542139b4", size = 18567, upload-time = "2025-05-07T22:47:40.376Z" }, ] +[[package]] +name = "plotly" +version = "6.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "narwhals" }, + { name = "packaging" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/6e/5c/0efc297df362b88b74957a230af61cd6929f531f72f48063e8408702ffba/plotly-6.2.0.tar.gz", hash = "sha256:9dfa23c328000f16c928beb68927444c1ab9eae837d1fe648dbcda5360c7953d", size = 6801941, upload-time = "2025-06-26T16:20:45.765Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ed/20/f2b7ac96a91cc5f70d81320adad24cc41bf52013508d649b1481db225780/plotly-6.2.0-py3-none-any.whl", hash = "sha256:32c444d4c940887219cb80738317040363deefdfee4f354498cc0b6dab8978bd", size = 9635469, upload-time = "2025-06-26T16:20:40.76Z" }, +] + [[package]] name = "pluggy" version = "1.6.0" @@ -2844,15 +2858,15 @@ wheels = [ [[package]] name = "starlette" -version = "0.47.1" +version = "0.47.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/0a/69/662169fdb92fb96ec3eaee218cf540a629d629c86d7993d9651226a6789b/starlette-0.47.1.tar.gz", hash = "sha256:aef012dd2b6be325ffa16698f9dc533614fb1cebd593a906b90dc1025529a79b", size = 2583072, upload-time = "2025-06-21T04:03:17.337Z" } +sdist = { url = "https://files.pythonhosted.org/packages/04/57/d062573f391d062710d4088fa1369428c38d51460ab6fedff920efef932e/starlette-0.47.2.tar.gz", hash = "sha256:6ae9aa5db235e4846decc1e7b79c4f346adf41e9777aebeb49dfd09bbd7023d8", size = 2583948, upload-time = "2025-07-20T17:31:58.522Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/82/95/38ef0cd7fa11eaba6a99b3c4f5ac948d8bc6ff199aabd327a29cc000840c/starlette-0.47.1-py3-none-any.whl", hash = "sha256:5e11c9f5c7c3f24959edbf2dffdc01bba860228acf657129467d8a7468591527", size = 72747, upload-time = "2025-06-21T04:03:15.705Z" }, + { url = "https://files.pythonhosted.org/packages/f7/1f/b876b1f83aef204198a42dc101613fefccb32258e5428b5f9259677864b4/starlette-0.47.2-py3-none-any.whl", hash = "sha256:c5847e96134e5c5371ee9fac6fdf1a67336d5815e09eb2a01fdb57a351ef915b", size = 72984, upload-time = "2025-07-20T17:31:56.738Z" }, ] [[package]] @@ -3113,6 +3127,7 @@ dependencies = [ { name = "gcsfs" }, { name = "imageio", extra = ["ffmpeg"] }, { name = "ipdb" }, + { name = "ipykernel" }, { name = "ipython" }, { name = "iteround" }, { name = "line-profiler" }, @@ -3122,6 +3137,7 @@ dependencies = [ { name = "orjson" }, { name = "ortools" }, { name = "pandas" }, + { name = "plotly" }, { name = "polars" }, { name = "pyarrow" }, { name = "pymunk" }, @@ -3136,6 +3152,7 @@ dependencies = [ { name = "tqdm" }, { name = "unum" }, { name = "vivarium-core" }, + { name = "vl-convert-python" }, { name = "xmltodict" }, ] @@ -3168,6 +3185,7 @@ requires-dist = [ { name = "gcsfs" }, { name = "imageio", extras = ["ffmpeg"] }, { name = "ipdb" }, + { name = "ipykernel" }, { name = "ipython" }, { name = "iteround" }, { name = "jupyter", marker = "extra == 'dev'" }, @@ -3181,6 +3199,7 @@ requires-dist = [ { name = "orjson" }, { name = "ortools", specifier = "<9.11" }, { name = "pandas" }, + { name = "plotly" }, { name = "polars" }, { name = "pre-commit", marker = "extra == 'dev'" }, { name = "pyarrow" }, @@ -3201,6 +3220,7 @@ requires-dist = [ { name = "tqdm" }, { name = "unum" }, { name = "vivarium-core" }, + { name = "vl-convert-python" }, { name = "xmltodict" }, ] provides-extras = ["dev", "docs"] @@ -3236,6 +3256,19 @@ dependencies = [ ] sdist = { url = "https://files.pythonhosted.org/packages/3e/87/794e0b4c5dccbca3036152fe5df56860a57e70f3e68ac0198dbd7df60fcb/vivarium-core-1.6.5.tar.gz", hash = "sha256:1d83faa60005304b548f623447ab8675a06bb7ed8f6b7c0bd25b4aaa3381fccb", size = 136102, upload-time = "2024-12-03T21:49:29.797Z" } +[[package]] +name = "vl-convert-python" +version = "1.8.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/15/08/06945bff9655c5b0520a8d1b2550cd8007e106ebec45a33840035420e0d2/vl_convert_python-1.8.0.tar.gz", hash = "sha256:ceca613ca5551c55270a15ca48d0f3a7de1e949e0f127310e9b0f6570ea3fbbb", size = 4651586, upload-time = "2025-05-28T00:06:47.6Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/22/5a/9dca7d8ff56e82c298e9ef381cfc803e262b85b7c59f2515d0e9f81a75b6/vl_convert_python-1.8.0-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:f663317fc280b07553534195c1e31c4ca882d9c8601430211b078196db5ed227", size = 29956698, upload-time = "2025-05-28T00:06:29.533Z" }, + { url = "https://files.pythonhosted.org/packages/42/e2/325e6b5895482b2534e7462c012f237c66ffb02fb3af45eec0accab2f8d4/vl_convert_python-1.8.0-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:81f6380019ceadf070a79f85aa624475a6568093f70de0e151a32e91ecbcaacf", size = 28831173, upload-time = "2025-05-28T00:06:32.925Z" }, + { url = "https://files.pythonhosted.org/packages/09/fa/1dd944c9e9898e59e31c385bdce215aca543acc555de20b8bf4dc60ddb89/vl_convert_python-1.8.0-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3388e3913287867b3553c10f81ca2d85268216a5a75e7c71b9c1b59887c1977e", size = 31668750, upload-time = "2025-05-28T00:06:36.158Z" }, + { url = "https://files.pythonhosted.org/packages/a7/6b/48f6d47a92eaf6f0dd235146307a7eb0d179b78d2faebc53aca3f1e49177/vl_convert_python-1.8.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b51264998e8fcc43dbce801484a950cfe6513cdc4c46b20604ef50989855a617", size = 32970141, upload-time = "2025-05-28T00:06:41.323Z" }, + { url = "https://files.pythonhosted.org/packages/f8/6f/29dce05f9167e3a01ab74d79eeadd531bc24cf59e3a7fc3736af476ca431/vl_convert_python-1.8.0-cp37-abi3-win_amd64.whl", hash = "sha256:9f1146b791ed27916f54c45e1d66af53a40eb26e5aaea1892f33eb9a935039ab", size = 31318167, upload-time = "2025-05-28T00:06:44.881Z" }, +] + [[package]] name = "wcwidth" version = "0.2.13"