# Source code for aoutools.prs._calculator_batch

"""PRS batch calculator"""

from typing import Optional
import logging
import hail as hl
import hailtop.fs as hfs
import pandas as pd
from aoutools._utils.helpers import SimpleTimer
from ._utils import _log_timing
from ._calculator_utils import (
    _prepare_samples_to_keep,
    _validate_and_prepare_weights_table,
    _orient_weights_for_split,
    _check_allele_match,
    _calculate_dosage,
    _prepare_weights_for_chunking,
    _create_1bp_intervals,
)
from ._config import PRSConfig

logger = logging.getLogger(__name__)


def _prepare_batch_weights_data(
    weights_tables_map: dict[str, hl.Table],
    config: PRSConfig,
) -> tuple[dict, hl.Table]:
    """
    Validates every weights table and collects the loci they cover.

    Each table in `weights_tables_map` is passed through
    `_validate_and_prepare_weights_table`. When `config.split_multi` is set,
    every validated table is additionally passed through
    `_orient_weights_for_split`. The key-only projections of the validated
    tables are unioned, re-keyed by `locus`, and de-duplicated to produce the
    table of loci that is returned alongside the prepared tables.

    Parameters
    ----------
    weights_tables_map : dict[str, hl.Table]
        A dictionary mapping score names to Hail tables of PRS weights.
    config : PRSConfig
        A configuration object. This function reads `split_multi` and
        `detailed_timings` and forwards the object to the helper functions.

    Returns
    -------
    tuple[dict, hl.Table]
        A tuple containing:
        - A dictionary mapping each score name to its prepared weights table.
        - A Hail table holding the distinct loci found across all tables.

        If `weights_tables_map` is empty, `({}, None)` is returned instead.

    See also
    --------
    PRSConfig : A configuration class that holds parameters for PRS calculation.
    """
    with _log_timing(
        "Preparing all weights tables", config.detailed_timings
    ):
        validated = {
            name: _validate_and_prepare_weights_table(
                weights_table=table,
                config=config
            )
            for name, table in weights_tables_map.items()
        }
        # `select()` with no arguments drops every non-key field, leaving
        # only what is needed to build the union of loci.
        locus_only_tables = [table.select() for table in validated.values()]

    # No input tables means there is nothing to union below.
    if not locus_only_tables:
        return {}, None

    if config.split_multi:
        # The split-multi path needs each table re-oriented by the helper;
        # the non-split path uses the validated tables unchanged.
        with _log_timing(
            "Re-keying weights tables for split-multi join",
            config.detailed_timings
        ):
            ready_weights = {
                name: _orient_weights_for_split(table, config)
                for name, table in validated.items()
            }
    else:
        ready_weights = validated

    first_table, *other_tables = locus_only_tables
    unique_loci = (
        first_table.union(*other_tables).key_by('locus').distinct()
    )
    return ready_weights, unique_loci


def _build_row_annotations(
    mt: hl.MatrixTable,
    mt_key: hl.expr.StructExpression,
    weights_tables_map: dict[str, hl.Table],
    prepared_weights: dict[str, hl.Table],
    config: PRSConfig,
) -> dict[str, hl.expr.Expression]:
    """
    Creates the per-score row annotation expressions used for PRS scoring.

    Two expressions are produced for every score name in
    `weights_tables_map`:
    - `weights_info_{score}`: the result of indexing that score's prepared
      weights table with `mt_key`.
    - `is_valid_{score}`: a boolean expression that is true when the lookup
      above is defined.

    When `config.split_multi` is false and `config.strict_allele_match` is
    true, `is_valid_{score}` is additionally AND-ed with the result of
    `_check_allele_match(mt, weights_info)`.

    Parameters
    ----------
    mt : hail.MatrixTable
        A MatrixTable containing genotype data. It is only forwarded to
        `_check_allele_match`.
    mt_key : hl.expr.StructExpression
        The expression used to index each prepared weights table.
    weights_tables_map : dict[str, hl.Table]
        A dictionary mapping score names to the original weights tables.
        Only its keys are used here.
    prepared_weights : dict[str, hl.Table]
        A dictionary mapping the same score names to the prepared weights
        tables that are indexed with `mt_key`.
    config : PRSConfig
        A configuration object. This function reads `split_multi` and
        `strict_allele_match`.

    Returns
    -------
    dict[str, hl.expr.Expression]
        A dictionary mapping annotation names to Hail expressions, suitable
        for passing to `MatrixTable.annotate_rows`.

    See also
    --------
    PRSConfig : A configuration class that holds parameters for PRS calculation.
    """
    # The allele check only applies on the non-split path and is the same
    # for every score, so decide once up front.
    check_alleles = not config.split_multi and config.strict_allele_match

    row_fields = {}
    for name in weights_tables_map:
        matched_row = prepared_weights[name][mt_key]
        row_is_valid = hl.is_defined(matched_row)
        if check_alleles:
            row_is_valid = row_is_valid & _check_allele_match(
                mt, matched_row
            )

        row_fields[f'weights_info_{name}'] = matched_row
        row_fields[f'is_valid_{name}'] = row_is_valid
    return row_fields


def _build_prs_agg_expr(
    mt: hl.MatrixTable,
    score_name: str,
    config: PRSConfig,
) -> hl.expr.Aggregation:
    """
    Creates the aggregation expression that sums one PRS score.

    For every entry the contribution is `dosage * weight` when the row's
    `is_valid_{score_name}` annotation is true, and `0.0` otherwise. The
    contributions are combined with `hl.agg.sum`.

    The dosage depends on `config.split_multi`: when set, it is
    `mt.GT.n_alt_alleles()`; otherwise it is whatever
    `_calculate_dosage(mt, score_name)` returns.

    Parameters
    ----------
    mt : hail.MatrixTable
        A MatrixTable that already carries the row fields
        `weights_info_{score_name}` and `is_valid_{score_name}` (as produced
        by `_build_row_annotations`). `weights_info_{score_name}` must expose
        a `weight` field.
    score_name : str
        The score identifier used to look up the two row fields above.
    config : PRSConfig
        A configuration object. This function reads `split_multi`.

    Returns
    -------
    hl.expr.Aggregation
        An aggregation expression for the sum of `dosage * weight` over the
        valid variants of the given score.

    See also
    --------
    PRSConfig : A configuration class that holds parameters for PRS calculation.
    """
    matched_weights = mt[f'weights_info_{score_name}']
    variant_is_valid = mt[f'is_valid_{score_name}']

    # Split data has one alternate allele per row, so the alt-allele count
    # is used directly; the non-split path defers to the helper.
    dosage = (
        mt.GT.n_alt_alleles()
        if config.split_multi
        else _calculate_dosage(mt, score_name)
    )

    # Invalid variants contribute exactly 0.0 to the running sum.
    contribution = hl.if_else(
        variant_is_valid, dosage * matched_weights.weight, 0.0
    )
    return hl.agg.sum(contribution)


def _calculate_prs_chunk_batch(
    vds: hl.vds.VariantDataset,
    weights_tables_map: dict[str, hl.Table],
    prepared_weights: dict[str, hl.Table],
    config: PRSConfig,
) -> hl.Table:
    """
    Computes every configured PRS for one chunk of a VariantDataset (VDS).

    The variant data of `vds` is taken as a MatrixTable (after
    `hl.vds.split_multi` when `config.split_multi` is set), annotated with
    the row fields from `_build_row_annotations`, and reduced to one
    aggregated value per column and score via `_build_prs_agg_expr`.

    When `config.include_n_matched` is set, the number of rows whose
    `is_valid_{score}` flag is true is counted for each score and attached
    to the result as `n_matched_{score}`.

    Parameters
    ----------
    vds : hl.vds.VariantDataset
        The VariantDataset chunk to score.
    weights_tables_map : dict[str, hl.Table]
        A dictionary mapping score names to their original weights tables.
        Its keys determine which scores are computed.
    prepared_weights : dict[str, hl.Table]
        A dictionary mapping score names to the prepared weights tables used
        for the row lookup.
    config : PRSConfig
        A configuration object. This function reads `split_multi`,
        `include_n_matched` and `detailed_timings`, and forwards the object
        to the helper functions.

    Returns
    -------
    hl.Table
        The columns table of the MatrixTable, with one field per score name
        and no globals. If `include_n_matched` is set, it also carries one
        `n_matched_{score}` field per score.

    See also
    --------
    PRSConfig : A configuration class that holds parameters for PRS calculation.
    """
    # Step 1: obtain the variant MatrixTable and the expression used to look
    # rows up in the weights tables. Split data is joined on the full row
    # key, non-split data on the locus alone.
    if config.split_multi:
        with _log_timing(
                "Planning: Splitting multi-allelic variants",
                config.detailed_timings
        ):
            variants = hl.vds.split_multi(vds).variant_data
            join_key = variants.row_key
    else:
        variants = vds.variant_data
        join_key = variants.locus

    with _log_timing(
        "Planning: Calculating and aggregating PRS scores",
        config.detailed_timings
    ):
        # Step 2: attach the matched weights row and validity flag per score.
        variants = variants.annotate_rows(
            **_build_row_annotations(
                variants, join_key, weights_tables_map, prepared_weights,
                config
            )
        )

        # Step 3: one aggregator per score, evaluated per column.
        per_score_sums = {}
        for name in weights_tables_map:
            per_score_sums[name] = _build_prs_agg_expr(
                variants, name, config
            )

        scores_by_sample = (
            variants.select_cols(**per_score_sums).cols().select_globals()
        )

    # Step 4 is optional; skip it entirely unless the counts were requested.
    if not config.include_n_matched:
        return scores_by_sample

    with _log_timing(
            "Computing shared variants count", config.detailed_timings
    ):
        valid_row_counters = hl.struct(**{
            f'n_matched_{name}': hl.agg.count_where(
                variants[f'is_valid_{name}']
            )
            for name in weights_tables_map
        })
        # `aggregate_rows` yields concrete counts, which are then attached
        # to every row of the result as constant fields.
        matched_counts = variants.aggregate_rows(valid_row_counters)
        scores_by_sample = scores_by_sample.annotate(**matched_counts)

    return scores_by_sample


def _process_chunks_batch(
    #pylint: disable=too-many-arguments
    #pylint: disable=too-many-positional-arguments
    n_chunks: int,
    chunked_loci: hl.Table,
    vds: hl.vds.VariantDataset,
    weights_tables_map: dict[str, hl.Table],
    prepared_weights: dict[str, hl.Table],
    config: PRSConfig,
) -> list[pd.DataFrame]:
    """
    Processes each genomic chunk to compute PRS values in batch mode.

    For every chunk ID in `range(n_chunks)`, this function selects the rows
    of `chunked_loci` with that `chunk_id`, restricts the VariantDataset
    (VDS) to the intervals built from those loci, restricts each prepared
    weights table to the same loci, and computes all scores with
    `_calculate_prs_chunk_batch`. Each chunk's result is converted to a
    pandas DataFrame.

    The per-chunk loci table is persisted while the chunk is processed and
    released again once the chunk's DataFrame has been materialised (or the
    chunk failed), so persisted data does not accumulate across chunks.

    Parameters
    ----------
    n_chunks : int
        The total number of genomic chunks to process.
    chunked_loci : hl.Table
        A Hail Table of loci carrying a `chunk_id` field.
    vds : hl.vds.VariantDataset
        A VariantDataset containing the full set of genotypes. It is never
        reassigned, so every chunk filters from the full dataset.
    weights_tables_map : dict[str, hl.Table]
        A dictionary mapping score names to their original weights tables.
    prepared_weights : dict[str, hl.Table]
        A dictionary mapping score names to the prepared weights tables.
    config : PRSConfig
        A configuration object. This function reads `sample_id_col` and
        forwards the object to `_calculate_prs_chunk_batch`.

    Returns
    -------
    list[pd.DataFrame]
        A list of pandas DataFrames, one per chunk, in chunk order. The `s`
        field of each result is renamed to `config.sample_id_col`.

    See also
    --------
    PRSConfig : A configuration class that holds parameters for PRS calculation.
    """
    partial_dfs = []
    for chunk_id in range(n_chunks):
        # Always show chunk processing time to track progress
        with _log_timing(
            f"Processing chunk {chunk_id + 1}/{n_chunks}", True
        ):
            loci_chunk = chunked_loci.filter(
                chunked_loci.chunk_id == chunk_id
            ).persist()

            try:
                intervals_to_filter = _create_1bp_intervals(loci_chunk)
                # Filter into a new variable: reassigning `vds` here would
                # make every later chunk filter an already-filtered dataset.
                vds_chunk = hl.vds.filter_intervals(
                    vds, intervals_to_filter, keep=True
                )

                # Filter the full weights table via semi-join to retain only
                # the current chunk's weights. semi_join preserves split
                # variants at the same locus when split_multi=True.
                chunked_prepared_weights = {
                    score_name: table.semi_join(loci_chunk)
                    for score_name, table in prepared_weights.items()
                }

                chunk_prs_table = _calculate_prs_chunk_batch(
                    vds_chunk,
                    weights_tables_map,
                    chunked_prepared_weights,
                    config
                )

                chunk_prs_table = chunk_prs_table.rename(
                    {'s': config.sample_id_col}
                )
                # `to_pandas` runs the pipeline, so nothing lazy still
                # depends on `loci_chunk` after this line.
                partial_dfs.append(chunk_prs_table.to_pandas())
            finally:
                # Release the persisted chunk; without this, one persisted
                # table per chunk would stay around for the whole run.
                loci_chunk.unpersist()
    return partial_dfs


def _aggregate_and_export_batch(
    partial_dfs: list[pd.DataFrame],
    output_path: str,
    config: PRSConfig
) -> None:
    """
    Combines the per-chunk results and writes them as a single CSV file.

    The DataFrames in `partial_dfs` are concatenated, grouped by
    `config.sample_id_col`, and every remaining column is summed within each
    group. The grouped result is written, with a header row and the sample
    ID as the index column, to `output_path` through `hailtop.fs`.

    If `partial_dfs` is empty, a warning is logged and nothing is written.

    Parameters
    ----------
    partial_dfs : list[pd.DataFrame]
        A list of pandas DataFrames containing chunk-wise PRS results.
    output_path : str
        The destination path of the comma-separated output file.
    config : PRSConfig
        A configuration object. This function reads `sample_id_col` and
        `detailed_timings`.

    Returns
    -------
    None

    See also
    --------
    PRSConfig : A configuration class that holds parameters for PRS calculation.
    """
    if not partial_dfs:
        logger.warning(
            "No PRS results were generated. No output file will be created."
        )
        return

    with _log_timing(
            "Aggregating batch results with Pandas", config.detailed_timings
    ):
        # Each sample appears once per chunk; summing per sample folds the
        # chunk-level partial values into one row per sample.
        stacked = pd.concat(partial_dfs, ignore_index=True)
        per_sample_totals = stacked.groupby(config.sample_id_col).sum()

    export_label = f"Exporting final result to {output_path}"
    with _log_timing(export_label, config.detailed_timings), \
            hfs.open(output_path, 'w') as sink:
        per_sample_totals.to_csv(sink, sep=',', header=True, index=True)


def calculate_prs_batch(
    weights_tables_map: dict[str, hl.Table],
    vds: hl.vds.VariantDataset,
    output_path: str,
    config: PRSConfig = PRSConfig(),
) -> Optional[str]:
    """
    Calculates multiple Polygenic Risk Scores (PRS) in a single batch run.

    The function optionally filters the VariantDataset to the samples in
    `config.samples_to_keep`, prepares all weights tables, splits the union
    of their loci into chunks, scores every chunk for all scores, and
    finally sums the per-chunk results per sample and writes them to
    `output_path`.

    Parameters
    ----------
    weights_tables_map : dict[str, hl.Table]
        A dictionary mapping score names to their corresponding PRS weights
        tables.
    vds : hl.vds.VariantDataset
        A Hail VariantDataset containing both variant and sample data.
    output_path : str
        A GCS path (starting with 'gs://') to write the final
        comma-separated output file.
    config : PRSConfig, optional
        A configuration object for all optional parameters. If not provided,
        default settings will be used. See the `PRSConfig` class for details
        on all available settings.

    Returns
    -------
    Optional[str]
        `output_path` if at least one chunk produced results. `None` if
        `weights_tables_map` yields no loci table (in which case nothing is
        written) or if no chunk results were produced.

    Raises
    ------
    ValueError
        If `output_path` does not start with 'gs://'. The weights
        preparation helpers may raise it as well (e.g. for a weights table
        that is empty after validation); their checks are not visible here.
    TypeError
        Presumably raised by `_prepare_samples_to_keep` when
        `config.samples_to_keep` has an unsupported type; confirm against
        that helper.

    See also
    --------
    PRSConfig : A configuration class that holds parameters for PRS calculation.
    """
    timer = SimpleTimer()
    with timer:
        if not output_path.startswith('gs://'):
            raise ValueError(
                "The 'output_path' must be a Google Cloud Storage (GCS) "
                "path, starting with 'gs://'."
            )

        logger.info(
            "Starting PRS calculation. \nFinal result will be at: %s",
            output_path,
        )

        if config.samples_to_keep is not None:
            with _log_timing(
                "Filtering to specified samples", config.detailed_timings
            ):
                samples_ht = _prepare_samples_to_keep(config.samples_to_keep)
                vds = hl.vds.filter_samples(vds, samples_ht)

        # Step 1: Prepare all weights data and get unique loci
        prepared_weights, loci_to_keep = \
            _prepare_batch_weights_data(weights_tables_map, config)

        if loci_to_keep is None:
            logger.warning(
                "No variants found in any weights table. Aborting."
            )
            return None

        # Step 2: Prepare loci for chunked processing
        count = loci_to_keep.count()
        logger.info(
            "Found %d total unique variants across all scores.", count
        )

        chunked_loci, n_chunks = _prepare_weights_for_chunking(
            weights_table=loci_to_keep,
            config=config,
            validate_table=False
        )

        # Step 3: Score every chunk for all scores
        partial_dfs = _process_chunks_batch(
            n_chunks=n_chunks,
            chunked_loci=chunked_loci,
            vds=vds,
            weights_tables_map=weights_tables_map,
            prepared_weights=prepared_weights,
            config=config,
        )

        # Step 4: Aggregate and export final results
        _aggregate_and_export_batch(
            partial_dfs=partial_dfs,
            output_path=output_path,
            config=config
        )

    # NOTE(review): logged after the timer context has exited, on the
    # assumption that SimpleTimer sets `duration` on exit; confirm.
    logger.info(
        "Batch PRS calculation complete. Total time: %.2f seconds.",
        timer.duration
    )
    return output_path if partial_dfs else None