"""PRS calculator"""
from typing import Optional
import logging
import hail as hl
import hailtop.fs as hfs
import pandas as pd
from aoutools._utils.helpers import SimpleTimer
from ._utils import _log_timing
from ._calculator_utils import (
_prepare_samples_to_keep,
_orient_weights_for_split,
_check_allele_match,
_calculate_dosage,
_prepare_weights_for_chunking,
_create_1bp_intervals,
)
from ._config import PRSConfig
logger = logging.getLogger(__name__)
def _prepare_mt_split(
    vds: hl.vds.VariantDataset,
    weights_table: hl.Table,
    config: PRSConfig
) -> hl.MatrixTable:
    """
    Build the MatrixTable used by the split-multi PRS path.

    Multi-allelic sites of the VDS are split, the weights chunk is passed
    through `_orient_weights_for_split`, and the result is joined onto the
    variant data by the MatrixTable row key. Rows that receive no weight are
    dropped, and a per-entry `dosage` is taken from the post-split `GT` call.

    Parameters
    ----------
    vds : hail.vds.VariantDataset
        An interval-filtered VariantDataset.
    weights_table : hail.Table
        A chunk of the PRS weights table.
    config : PRSConfig
        Configuration object. `detailed_timings` is read here; the whole
        object is also forwarded to `_orient_weights_for_split`.

    Returns
    -------
    hail.MatrixTable
        The split variant data, restricted to rows with a defined
        `weights_info` and carrying a `dosage` entry field.

    See also
    --------
    PRSConfig : A configuration class that holds parameters for PRS
        calculation.
    """
    show_timings = config.detailed_timings

    with _log_timing(
        "Planning: Splitting multi-allelic variants and joining",
        show_timings
    ):
        split_mt = hl.vds.split_multi(vds).variant_data
        oriented_weights = _orient_weights_for_split(weights_table, config)

        # Look the weights up by the full row key of the split data and
        # keep only the variants that actually have a weight.
        matched_weights = oriented_weights[split_mt.row_key]
        split_mt = split_mt.annotate_rows(weights_info=matched_weights)
        has_weight = hl.is_defined(split_mt.weights_info)
        split_mt = split_mt.filter_rows(has_weight)

    with _log_timing(
        "Planning: Calculating per-variant dosage",
        show_timings,
    ):
        # `hl.vds.split_multi` converts LGT to GT (see its source code), so
        # the built-in alternate-allele count can be used directly.
        alt_allele_count = split_mt.GT.n_alt_alleles()
        split_mt = split_mt.annotate_entries(dosage=alt_allele_count)

    return split_mt
def _prepare_mt_non_split(
    vds: hl.vds.VariantDataset,
    weights_table: hl.Table,
    config: PRSConfig
) -> hl.MatrixTable:
    """
    Build the MatrixTable used by the non-split PRS path.

    The variant data of the VDS is annotated with weights looked up by
    locus, and rows without a weight are removed. When
    `config.strict_allele_match` is set, rows are further restricted to
    those for which `_check_allele_match` evaluates to true. Finally a
    per-entry `dosage` is computed with `_calculate_dosage`.

    Parameters
    ----------
    vds : hail.vds.VariantDataset
        An interval-filtered Variant Dataset.
    weights_table : hail.Table
        A chunk of the weights table; it is indexed with `mt.locus`.
    config : PRSConfig
        Configuration object providing `strict_allele_match` and
        `detailed_timings`.

    Returns
    -------
    hail.MatrixTable
        The filtered variant data, annotated with a `weights_info` row
        field and a `dosage` entry field.

    See also
    --------
    PRSConfig : A configuration class that holds parameters for PRS
        calculation.
    """
    show_timings = config.detailed_timings
    variant_mt = vds.variant_data

    with _log_timing(
        "Planning: Annotating variants with weights", show_timings
    ):
        locus_weights = weights_table[variant_mt.locus]
        variant_mt = variant_mt.annotate_rows(weights_info=locus_weights)
        has_weight = hl.is_defined(variant_mt.weights_info)
        variant_mt = variant_mt.filter_rows(has_weight)

    if config.strict_allele_match:
        with _log_timing(
            "Planning: Performing strict allele match",
            show_timings
        ):
            alleles_agree = _check_allele_match(
                variant_mt, variant_mt.weights_info
            )
            variant_mt = variant_mt.filter_rows(alleles_agree)

    with _log_timing(
        "Planning: Calculating per-variant dosage",
        show_timings,
    ):
        per_entry_dosage = _calculate_dosage(variant_mt)
        variant_mt = variant_mt.annotate_entries(dosage=per_entry_dosage)

    return variant_mt
def _calculate_prs_chunk(
    weights_table: hl.Table,
    vds: hl.vds.VariantDataset,
    config: PRSConfig
) -> hl.Table:
    """
    Compute the partial Polygenic Risk Score (PRS) for one chunk of variants.

    The variant data is prepared by `_prepare_mt_split` or
    `_prepare_mt_non_split`, depending on `config.split_multi`. The score
    of each sample is the sum of `dosage * weights_info.weight` over the
    prepared rows.

    Parameters
    ----------
    weights_table : hail.Table
        A pre-filtered chunk of the full weights table, keyed by 'locus'.
    vds : hail.vds.VariantDataset
        The Variant Dataset containing genotypes to score.
    config : PRSConfig
        Configuration object providing `split_multi`, `include_n_matched`,
        `detailed_timings` and `sample_id_col`.

    Returns
    -------
    hail.Table
        One row per sample with a `prs` field. If
        `config.include_n_matched` is set, an `n_matched` field holding the
        number of rows in the prepared MatrixTable is added (the same value
        for every sample). The 's' field is renamed to
        `config.sample_id_col` and all globals are dropped.

    See also
    --------
    PRSConfig : A configuration class that holds parameters for PRS
        calculation.
    """
    # Pick the preparation routine once; both share the same signature.
    prepare_mt = (
        _prepare_mt_split if config.split_multi else _prepare_mt_non_split
    )
    mt = prepare_mt(
        vds=vds,
        weights_table=weights_table,
        config=config,
    )

    # Per-sample aggregation of the weighted dosages in this chunk.
    weighted_dosage = mt.dosage * mt.weights_info.weight
    scored_mt = mt.select_cols(prs=hl.agg.sum(weighted_dosage))
    prs_table = scored_mt.cols()

    if config.include_n_matched:
        with _log_timing(
            "Computing shared variants count", config.detailed_timings
        ):
            # A per-sample hl.agg.count() inside `select_cols` would not
            # give the number of matched variants: homozygous-reference
            # calls are stored as missing, so the count would differ per
            # sample. Counting rows instead needs a second pass over the
            # data and is therefore less performant.
            n_matched = mt.count_rows()
            logger.info("%d variants in common in this chunk.", n_matched)
            prs_table = prs_table.annotate(n_matched=n_matched)

    # Expose the sample ID under the user-defined column name.
    renamed = prs_table.rename({'s': config.sample_id_col})

    # Drop all global annotations to minimize memory footprint.
    return renamed.select_globals()
def _process_chunks(
    full_weights_table: hl.Table,
    n_chunks: int,
    vds: hl.vds.VariantDataset,
    config: PRSConfig
) -> list[pd.DataFrame]:
    """
    Iteratively process each chunk of the weights table.

    For every chunk index in ``range(n_chunks)`` the weights table is
    filtered to that chunk and persisted, the VDS is restricted to the
    intervals built by `_create_1bp_intervals`, the partial PRS is computed
    with `_calculate_prs_chunk`, and the result is converted to a Pandas
    DataFrame. The persisted chunk is released again once its DataFrame has
    been materialised (or if processing the chunk fails), so persisted data
    does not accumulate across chunks.

    Parameters
    ----------
    full_weights_table : hail.Table
        The full weights table, annotated with a 'chunk_id' field.
    n_chunks : int
        The total number of chunks to process.
    vds : hail.vds.VariantDataset
        The Variant Dataset containing genotype data, optionally
        filtered for samples.
    config : PRSConfig
        Configuration object forwarded to `_calculate_prs_chunk`.

    Returns
    -------
    list[pd.DataFrame]
        One DataFrame per chunk, in chunk order, each holding the partial
        PRS results for that chunk. Empty if `n_chunks` is 0.

    See also
    --------
    PRSConfig : A configuration class that holds parameters for PRS
        calculation.
    """
    partial_dfs = []
    for chunk_idx in range(n_chunks):
        # Always show chunk processing time to track progress.
        with _log_timing(
            f"Processing chunk {chunk_idx + 1}/{n_chunks}", True
        ):
            # Persist the chunk because it is consumed more than once:
            #   1. when the 1bp intervals are created, and
            #   2. when rows are annotated with the PRS weights.
            weights_chunk = full_weights_table.filter(
                full_weights_table.chunk_id == chunk_idx
            ).persist()
            try:
                intervals_to_filter = _create_1bp_intervals(weights_chunk)

                # Keep the filtered data in its own variable. Rebinding
                # `vds` here would make every later chunk filter an
                # already-filtered dataset instead of the full one.
                vds_chunk = hl.vds.filter_intervals(
                    vds, intervals_to_filter, keep=True
                )

                chunk_prs_table = _calculate_prs_chunk(
                    weights_table=weights_chunk,
                    vds=vds_chunk,
                    config=config
                )

                # `to_pandas()` fully materialises the result, so nothing
                # lazy depends on the persisted chunk afterwards.
                partial_dfs.append(chunk_prs_table.to_pandas())
            finally:
                # Release the persisted chunk; otherwise one persisted
                # table per chunk is kept for the whole run.
                weights_chunk.unpersist()

    return partial_dfs
def _aggregate_and_export(
    partial_dfs: list[pd.DataFrame],
    output_path: str,
    config: PRSConfig
) -> None:
    """
    Combine the per-chunk DataFrames and write the final result.

    The partial DataFrames are concatenated, grouped by the sample ID
    column and summed, and the aggregated frame is written as a
    comma-separated file to `output_path`. If `partial_dfs` is empty a
    warning is logged and nothing is written.

    Parameters
    ----------
    partial_dfs : list[pd.DataFrame]
        A list of Pandas DataFrames, where each contains partial PRS results
        for a chunk.
    output_path : str
        Destination path of the comma-separated file; opened for writing
        through `hailtop.fs`.
    config : PRSConfig
        Configuration object providing `sample_id_col` and
        `detailed_timings`.

    Returns
    -------
    None

    See also
    --------
    PRSConfig : A configuration class that holds parameters for PRS
        calculation.
    """
    if not partial_dfs:
        logger.warning(
            "No PRS results were generated. No output file will be created."
        )
        return None

    show_timings = config.detailed_timings

    with _log_timing("Aggregating results with Pandas", show_timings):
        # Stack all chunks, then sum every numeric column per sample.
        stacked = pd.concat(partial_dfs, ignore_index=True)
        per_sample_totals = stacked.groupby(config.sample_id_col).sum()

    with _log_timing(
        f"Exporting final result to {output_path}", show_timings
    ):
        with hfs.open(output_path, 'w') as out_file:
            # The sample ID is the index after the group-by, so the index
            # must be written for it to appear in the file.
            per_sample_totals.to_csv(
                out_file, sep=',', index=True, header=True
            )

    return None
# [docs] -- leftover Sphinx "view source" link marker; not part of the module
def calculate_prs(
    weights_table: hl.Table,
    vds: hl.vds.VariantDataset,
    output_path: str,
    config: PRSConfig = PRSConfig()
) -> Optional[str]:
    """
    Calculate a Polygenic Risk Score (PRS) and export the result to a file.

    Main entry point of the PRS workflow. The weights table is processed in
    chunks; for each chunk the VDS is restricted with a filter_intervals
    approach, a partial score is computed and converted to a Pandas
    DataFrame, and the partial results are finally aggregated and written
    to `output_path`.

    Notes
    -----
    By default (`config.split_multi=True`), this function prioritizes
    robustness over performance by splitting multi-allelic variants.
    The split_multi process includes creating a minimal representation for
    variants. For example, for a variant chr1:10075251 A/G in the weights
    table, split_multi can match it to a complex indel in the VDS
    (e.g., alleles=['AGGGC', 'A', 'GGGGC']) by simplifying the VDS
    representation to its minimal form (['A', 'G']) for 'AGGGC' -> 'GGGGC'.

    The non-split path (`config.split_multi=False`) is a faster but less
    robust alternative. It relies on a direct string comparison of alleles
    and will fail to match the complex variant described above.
    Furthermore, if the weights table contains multiple entries for the
    same locus, the non-split path will arbitrarily select only one of
    them. This "power-user" option should only be used if you are certain
    that both your VDS and weights table contain only simple, well-matched,
    bi-allelic variants.

    Parameters
    ----------
    weights_table : hail.Table
        A Hail table containing variant weights. Must contain the following
        columns:

        - `chr`: str
        - `pos`: int32
        - `effect_allele`: str
        - `noneffect_allele`: str
        - A column for the effect weight (float64), specified by
          `weight_col_name`.
    vds : hail.vds.VariantDataset
        A Hail VariantDataset containing both variant and sample data.
    output_path : str
        A GCS path (starting with 'gs://') to write the final
        comma-separated output file.
    config : PRSConfig, optional
        A configuration object for all optional parameters. If not
        provided, default settings will be used. See the `PRSConfig` class
        for details on all available settings.

    Returns
    -------
    str or None
        The output path if at least one partial result was produced;
        otherwise, None. The output file is a comma-separated text file
        with:

        - A sample ID column (as configured in `config.sample_id_col`)
        - `prs`: The calculated PRS value
        - `n_matched` (optional): The number of variants used to calculate
          the score, included if `config.include_n_matched` is True.

    Raises
    ------
    ValueError
        If `output_path` is not a valid GCS path, or if the `weights_table`
        is empty after validation.
    TypeError
        If the `config.samples_to_keep` argument is of an unsupported type.

    See also
    --------
    PRSConfig : A configuration class that holds parameters for PRS
        calculation.
    """
    timer = SimpleTimer()
    with timer:
        is_gcs_path = output_path.startswith('gs://')
        if not is_gcs_path:
            raise ValueError(
                "The 'output_path' must be a Google Cloud Storage (GCS) "
                "path, starting with 'gs://'."
            )

        logger.info(
            "Starting PRS calculation. Final result will be at: %s",
            output_path,
        )

        # Optional sample subsetting happens once, before any chunking.
        requested_samples = config.samples_to_keep
        if requested_samples is not None:
            with _log_timing(
                "Planning: Filtering to specified samples",
                config.detailed_timings
            ):
                keep_ht = _prepare_samples_to_keep(requested_samples)
                vds = hl.vds.filter_samples(vds, keep_ht)

        chunked_weights, chunk_count = _prepare_weights_for_chunking(
            weights_table=weights_table,
            config=config,
            validate_table=True,
        )

        chunk_results = _process_chunks(
            full_weights_table=chunked_weights,
            n_chunks=chunk_count,
            vds=vds,
            config=config,
        )

        _aggregate_and_export(
            partial_dfs=chunk_results,
            output_path=output_path,
            config=config,
        )

    # The duration is only final once the timer context has been exited.
    logger.info(
        "PRS calculation complete. Total time: %.2f seconds.", timer.duration
    )

    if not chunk_results:
        return None
    return output_path