Source code for aoutools.prs._workflow

"""
High-level workflow function for end-to-end PGS calculation.
"""

import logging
import tempfile
from pathlib import Path
from typing import Optional, Union, Iterable

import hail as hl

from ._downloader import download_pgs
from ._reader import read_prs_weights
from ._calculator_batch import calculate_prs_batch
from ._config import PRSConfig

# Module-level logger named after this module, so applications can configure
# its output through the standard ``logging`` hierarchy.
logger = logging.getLogger(__name__)


def calculate_pgs(
    *,
    vds: hl.vds.VariantDataset,
    output_path: str,
    pgs: Union[Iterable[str], str],
    build: Optional[str] = "GRCh38",
    config: Optional[PRSConfig] = None,
    user_agent: Optional[str] = None,
    verbose: bool = False,
) -> Optional[str]:
    """
    Downloads specified PGS Catalog scoring files and calculates PRS.

    This function automates a controlled workflow:

    1.  **Download**: Fetches scoring files from the PGS Catalog for a
        specific list of PGS IDs into a temporary directory.
    2.  **Read**: Parses every downloaded ``*.txt.gz`` scoring file with
        `read_prs_weights`. Files that fail to parse are skipped with a
        warning.
    3.  **Calculate**: Passes all successfully parsed scores to
        `calculate_prs_batch` together with `vds`, `output_path` and
        `config`.

    Notes
    -----
    This function does not accept EFO traits or PGP publication IDs to
    prevent the unexpected download of a large number of scoring files,
    which may overwhelm the storage or computational resources of a typical
    workspace.

    Scoring files are processed in sorted path order so that the order in
    which scores are handed to the batch calculator is reproducible across
    runs and filesystems.

    Parameters
    ----------
    vds : hail.vds.VariantDataset
        A Hail VariantDataset containing the genotype data to be scored.
    output_path : str
        A GCS path (e.g., 'gs://bucket/results.csv') for the output file.
    pgs : str or iterable of str
        One or more PGS Catalog ID(s) (e.g., "PGS000771") to download. This
        argument is required.
    build : str, optional
        The genome build for harmonized scores ("GRCh37" or "GRCh38").
        Defaults to "GRCh38".
    config : PRSConfig, optional
        A configuration object for calculation parameters. When omitted (or
        ``None``), a fresh ``PRSConfig()`` is created for this call.
    user_agent : str, optional
        A custom user agent string for PGS Catalog API requests.
    verbose : bool, default False
        Enable verbose logging for the download process.

    Returns
    -------
    str or None
        The value returned by `calculate_prs_batch` (the output path when
        results are written). ``None`` is returned early when no scoring
        file was downloaded or when none of the downloaded files could be
        read.

    Raises
    ------
    Exception
        Errors raised by `download_pgs` or `calculate_prs_batch` are not
        caught here and propagate to the caller (for example, download
        failures, or validation errors for `output_path` / `config`
        reported by the batch calculator).

        ``ValueError`` and ``TypeError`` raised by `read_prs_weights` for an
        individual scoring file are *not* propagated: the file is skipped
        and a warning is logged.
    """
    # Build the default per call instead of in the signature, so that calls
    # never share a single PRSConfig instance created at import time.
    if config is None:
        config = PRSConfig()

    # Define the column mapping for standard PGS Catalog scoring files
    pgs_column_map = {
        'chr': 'hm_chr',
        'pos': 'hm_pos',
        'effect_allele': 'effect_allele',
        'noneffect_allele': 'other_allele',
        'weight': 'effect_weight',
    }

    # Everything that touches the downloaded files, including the final
    # calculation, stays inside the ``with`` block so the files still exist
    # while they are being used.
    with tempfile.TemporaryDirectory() as temp_dir:
        logger.info(
            "Created temporary directory for PGS downloads: %s", temp_dir
        )

        # Step 1: Download scoring files for specified PGS IDs
        download_pgs(
            outdir=temp_dir,
            pgs=pgs,
            build=build,
            user_agent=user_agent,
            verbose=verbose,
            overwrite_existing_file=True,
        )

        # Step 2: Read downloaded files and prepare for batch calculation.
        # rglob yields entries in filesystem order; sort for determinism.
        downloaded_files = sorted(Path(temp_dir).rglob("*.txt.gz"))
        if not downloaded_files:
            logger.warning(
                "No scoring files were downloaded for the specified PGS IDs."
            )
            return None

        weights_tables_map = {}
        for file_path in downloaded_files:
            # The score name is the file name up to its first dot.
            score_name = file_path.name.split('.')[0]
            logger.info(
                "Reading scoring file for '%s' from %s", score_name, file_path
            )
            try:
                weights_table = read_prs_weights(
                    file_path=str(file_path),
                    header=True,
                    column_map=pgs_column_map,
                    delimiter='\t',
                    comment='#',
                    validate_alleles=True,
                    missing='',
                    force=True,
                )
            except (ValueError, TypeError) as e:
                # Best effort: one unreadable file must not abort the batch.
                logger.warning(
                    "Could not read scoring file for '%s'. Skipping. Error: %s. "
                    "Please check the format of the downloaded file; some "
                    "scoring files may contain variants without positional "
                    "information (e.g., HLA types) or with incomplete effect "
                    "allele information, which are not supported.",
                    score_name, e
                )
            else:
                weights_tables_map[score_name] = weights_table

        if not weights_tables_map:
            logger.warning("No valid scoring files could be read. Aborting.")
            return None

        # Step 3: Calculate PRS for all scores in batch mode
        logger.info(
            "Calculating PRS for %d scores in batch mode.",
            len(weights_tables_map)
        )
        result_path = calculate_prs_batch(
            weights_tables_map=weights_tables_map,
            vds=vds,
            output_path=output_path,
            config=config,
        )
        return result_path