"""Reader for PRS weights files"""
from typing import Union
import logging
import hail as hl
import hailtop.fs as hfs
from aoutools._utils.helpers import SimpleTimer
from ._utils import (
_stage_local_file_to_gcs,
_standardize_chromosome_column,
)
logger = logging.getLogger(__name__)
def _validate_alleles(
    table: hl.Table
) -> hl.Table:
    """
    Drop rows whose alleles contain anything other than A, C, G or T.

    Both allele fields are matched against an anchored regular expression
    that accepts one or more bases, so multi-base alleles (indels) are
    checked the same way as single-base alleles (SNPs). The table is counted
    before and after filtering so the number of removed rows can be reported.

    Parameters
    ----------
    table : hail.Table
        A Hail Table with `effect_allele` and `noneffect_allele` fields.

    Returns
    -------
    hail.Table
        The input table restricted to rows where both alleles are valid.
    """
    logger.info("Validating allele columns for non-ACGT characters...")
    acgt_only = '^[ACGT]+$'

    n_before = table.count()

    # Build one predicate per allele column, then keep rows passing both.
    effect_ok = hl.str(table.effect_allele).matches(acgt_only)
    noneffect_ok = hl.str(table.noneffect_allele).matches(acgt_only)
    table = table.filter(effect_ok & noneffect_ok)

    n_dropped = n_before - table.count()
    if n_dropped > 0:
        logger.warning(
            "Removed %d variants with invalid alleles (non-ACGT characters found).",
            n_dropped
        )
    return table
def _check_duplicated_ids(
    table: hl.Table,
    file_path: str = "input"
) -> None:
    """
    Raise if the table contains the same variant more than once.

    A variant ID is built by joining chromosome, position, non-effect allele
    and effect allele with underscores. Rows are grouped by that ID and any
    ID that occurs more than once is treated as a duplicate.

    Parameters
    ----------
    table : hl.Table
        A Hail Table to validate. Must contain the fields 'chr', 'pos',
        'noneffect_allele', and 'effect_allele'.
    file_path : str, optional
        A source file path to display in error messages. Default is "input".

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If any duplicate variants are found. The message lists up to five
        of the offending IDs.
    """
    logger.info(
        "Checking for duplicate variants based on chr, pos, and alleles..."
    )
    table_with_id = table.annotate(
        variant_id=hl.str('_').join([
            table.chr, hl.str(table.pos), table.noneffect_allele,
            table.effect_allele
        ])
    )
    id_counts_table = table_with_id.group_by('variant_id').aggregate(
        n=hl.agg.count()
    )
    duplicate_variants = id_counts_table.filter(id_counts_table.n > 1)

    # One take() both detects duplicates and collects the examples, so the
    # grouped aggregation is not evaluated a second time just for a count.
    example_duplicates = duplicate_variants.take(5)
    if example_duplicates:
        formatted_examples = [d.variant_id for d in example_duplicates]
        raise ValueError(
            f"Duplicate variants found in '{file_path}'. Examples of "
            f"duplicate IDs: {formatted_examples}."
        )
    logger.info("No duplicate variants found.")
def _process_prs_weights_table(
    table: hl.Table,
    file_path: str,
    validate_alleles: bool
) -> hl.Table:
    """
    Apply the shared post-import filtering and validation to a weights table.

    Steps, in order:

    1. If requested, drop rows whose alleles are not pure ACGT strings.
    2. Pass the table through `_standardize_chromosome_column`.
    3. Drop rows whose weight is missing or equal to zero.
    4. Fail if no rows remain.
    5. Fail if duplicate variants are present.
    6. Log how many variants were loaded.

    Parameters
    ----------
    table : hail.Table
        A Hail table immediately after import and column standardization.
    file_path : str
        A source file path used for logging and error messages.
    validate_alleles : bool
        If True, validates that allele columns contain only ACGT characters.

    Returns
    -------
    hail.Table
        The filtered, validated and persisted Hail Table.

    Raises
    ------
    ValueError
        If no rows remain after filtering, or if duplicate variants (defined
        by chromosome, position, and alleles) are found in the table.
    """
    if validate_alleles:
        table = _validate_alleles(table)
    table = _standardize_chromosome_column(table)

    # Row count going into the weight filter, used to report removals.
    n_before_weight_filter = table.count()

    weight = table.weight
    has_effect = hl.is_defined(weight) & (weight != 0)

    # Persisted because the filtered table is counted here and then scanned
    # again by the duplicate check below.
    table = table.filter(has_effect).persist()
    n_kept = table.count()

    n_dropped = n_before_weight_filter - n_kept
    if n_dropped > 0:
        logger.info(
            "Removed %d variants with missing or zero-effect weights.",
            n_dropped
        )

    if n_kept == 0:
        raise ValueError(
            f"Input file '{file_path}' is empty or all variants were "
            f"filtered out due to missing/zero weights or invalid alleles."
        )

    _check_duplicated_ids(table, file_path=file_path)

    logger.info(
        "Successfully loaded %d variants from %s.",
        n_kept,
        file_path
    )
    return table
def _read_prs_weights_noheader(
    #pylint: disable=too-many-arguments
    #pylint: disable=too-many-positional-arguments
    file_path: str,
    column_map: dict,
    delimiter: str,
    comment: Union[str, list[str]],
    keep_other_cols: bool = False,
    validate_alleles: bool = False,
    **kwargs
) -> hl.Table:
    """
    Reads a weight file without a header using a column map of indices.

    The file is imported with `no_header=True`, the 1-based indices in
    `column_map` are translated to the imported `f<index - 1>` fields, the
    required columns are renamed to their standard names (with `pos` cast to
    int32 and `weight` to float64), and the result is handed to
    `_process_prs_weights_table`.

    Parameters
    ----------
    file_path : str
        A path to the weight file.
    column_map : dict
        A dictionary mapping standard names to 1-based integer indices.
    delimiter : str
        A field delimiter.
    comment : str or list[str]
        A character, or list of characters, that denote comment lines
        to be ignored.
    keep_other_cols : bool
        If True, all non-required columns are preserved. They are renamed
        `non_req_col_1`, `non_req_col_2`, ... in file order.
    validate_alleles : bool
        If True, validates that allele columns contain only ACGT characters.
    **kwargs : dict, optional
        Other keyword arguments to pass directly to `hail.import_table`, such
        as `missing` or `min_partitions`.

    Returns
    -------
    hail.Table
        A processed Hail Table.

    Raises
    ------
    ValueError
        If `column_map` contains invalid indices (less than 1, duplicated, or
        larger than the number of columns in the file), if the table is
        empty after filtering for missing weights, or if duplicate variants
        are found.
    """
    logger.info("Importing file (no header): '%s'", file_path)

    indices = list(column_map.values())
    if any(i < 1 for i in indices):
        raise ValueError(
            "Column indices in column_map must be 1-based and cannot be "
            "less than 1."
        )
    if len(indices) != len(set(indices)):
        raise ValueError("Duplicate column indices provided in column_map.")

    table = hl.import_table(
        file_path,
        delimiter=delimiter,
        no_header=True,
        comment=comment,
        **kwargs
    )

    # Reject indices past the last column up front; otherwise the field
    # lookups below fail with an opaque Hail lookup error instead of the
    # ValueError this function documents for invalid indices.
    all_f_fields = table.row_value.dtype.fields
    available_fields = set(all_f_fields)
    out_of_range = sorted(
        i for i in indices if f'f{i - 1}' not in available_fields
    )
    if out_of_range:
        raise ValueError(
            f"Column indices {out_of_range} in column_map exceed the "
            f"{len(available_fields)} column(s) found in '{file_path}'."
        )

    standard_cols_exprs = {
        'chr': table[f"f{column_map['chr'] - 1}"],
        'pos': hl.int32(table[f"f{column_map['pos'] - 1}"]),
        'effect_allele': table[f"f{column_map['effect_allele'] - 1}"],
        'noneffect_allele': table[f"f{column_map['noneffect_allele'] - 1}"],
        'weight': hl.float64(table[f"f{column_map['weight'] - 1}"])
    }

    other_cols_exprs = {}
    if keep_other_cols:
        used_f_fields = {f'f{i - 1}' for i in indices}
        other_fields = [f for f in all_f_fields if f not in used_f_fields]
        new_names = [f'non_req_col_{i+1}' for i, _ in enumerate(other_fields)]
        other_cols_exprs = {
            new: table[old] for new, old in zip(new_names, other_fields)
        }

    table = table.select(**standard_cols_exprs, **other_cols_exprs)
    return _process_prs_weights_table(table, file_path, validate_alleles)
def _read_prs_weights_header(
    #pylint: disable=too-many-arguments
    #pylint: disable=too-many-positional-arguments
    file_path: str,
    column_map: dict,
    delimiter: str,
    comment: Union[str, list[str]],
    keep_other_cols: bool = False,
    validate_alleles: bool = False,
    **kwargs
) -> hl.Table:
    """
    Reads a weight file with a header using a column map of names.

    The file is imported with explicit types for the required columns
    (`pos` as int32, `weight` as float64, the rest as strings), the
    user-supplied column names are renamed to their standard names, and the
    result is handed to `_process_prs_weights_table`.

    Parameters
    ----------
    file_path : str
        A path to the weight file.
    column_map : dict
        A dictionary mapping standard names to user-defined column names.
    delimiter : str
        A field delimiter.
    comment : str or list[str]
        A character, or list of characters, that denote comment lines
        to be ignored.
    keep_other_cols : bool
        If True, all non-required columns are preserved under their
        original names.
    validate_alleles : bool
        If True, validates that allele columns contain only ACGT characters.
    **kwargs : dict, optional
        Other keyword arguments to pass directly to `hail.import_table`, such
        as `missing` or `min_partitions`.

    Returns
    -------
    hail.Table
        A processed Hail Table.

    Raises
    ------
    ValueError
        If `column_map` contains duplicate names, if specified columns
        are not in the file's header, if `keep_other_cols` is True and an
        unmapped column shares its name with a standardized column, if the
        table is empty after filtering for missing weights, or if duplicate
        variants are found.
    """
    logger.info("Importing file (with header): '%s'", file_path)

    col_names = list(column_map.values())
    if len(col_names) != len(set(col_names)):
        raise ValueError("Duplicate column names provided in column_map.")

    types = {
        column_map['chr']: hl.tstr,
        column_map['pos']: hl.tint32,
        column_map['effect_allele']: hl.tstr,
        column_map['noneffect_allele']: hl.tstr,
        column_map['weight']: hl.tfloat64,
    }
    table = hl.import_table(
        file_path,
        delimiter=delimiter,
        no_header=False,
        types=types,
        comment=comment,
        **kwargs
    )

    missing = set(col_names) - set(table.row)
    if missing:
        raise ValueError(f"Required columns not in header: {missing}")

    standard_exprs = {
        'chr': table[column_map['chr']],
        'pos': table[column_map['pos']],
        'effect_allele': table[column_map['effect_allele']],
        'noneffect_allele': table[column_map['noneffect_allele']],
        'weight': table[column_map['weight']]
    }

    other_exprs = {}
    if keep_other_cols:
        other_fields = [f for f in table.row if f not in col_names]

        # An unmapped column named e.g. 'weight' would collide with the
        # standardized column of the same name in the select() call below,
        # which otherwise fails with a confusing duplicate-keyword TypeError.
        clashes = sorted(set(other_fields) & set(standard_exprs))
        if clashes:
            raise ValueError(
                f"Cannot keep non-required columns {clashes}: their names "
                "clash with the standardized column names. Rename them in "
                "the input file or set keep_other_cols=False."
            )
        other_exprs = {f: table[f] for f in other_fields}

    table = table.select(**standard_exprs, **other_exprs)
    return _process_prs_weights_table(table, file_path, validate_alleles)
def _validate_column_map_type(column_map: dict, header: bool):
"""
Validate that all values in `column_map` are of the expected type based on
`header`.
Parameters
----------
column_map : dict
A dictionary mapping standard keys to column names or indices.
header : bool
Indicates if the input file has a header row.
- If True, all values in `column_map` must be strings (column names).
- If False, all values must be integers (1-based column indices).
Raises
------
TypeError
If any value in `column_map` does not match the expected type based on
`header`.
"""
expected_type = str if header else int
if not all(isinstance(v, expected_type) for v in column_map.values()):
raise TypeError(
f"With header={header}, column_map values must be "
f"{expected_type.__name__}s."
)
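# NOTE(review): stray "[docs]" marker removed here; it appears to be a documentation-site link artifact, not code.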
def read_prs_weights(
    #pylint: disable=too-many-arguments
    #pylint: disable=too-many-positional-arguments
    file_path: str,
    header: bool,
    column_map: dict[str, Union[str, int]],
    delimiter: str = ',',
    comment: Union[str, list[str]] = '#',
    keep_other_cols: bool = False,
    validate_alleles: bool = False,
    **kwargs
) -> hl.Table:
    """
    Reads a file containing variant effect weights for PRS calculation.

    This function requires an active Hail-enabled environment. It uses a
    flexible `column_map` dictionary to handle various input file formats.
    After standardizing the required columns, the function performs several
    validation checks, filtering out variants with missing weights, invalid
    alleles (if `validate_alleles=True`), or raising an error for duplicates.

    If a local file path is provided, it is automatically copied to a temporary
    directory in your GCS bucket for Hail access. `column_map` is validated
    before that copy is attempted, so invalid arguments fail without
    uploading anything.

    Parameters
    ----------
    file_path : str
        A path to the weight file (local or gs://).
    header : bool
        If True, `column_map` values should be strings (column names).
        If False, `column_map` values should be 1-based integers (column
        indices).
    column_map : dict
        A dictionary mapping standard names to user-defined names or indices.
        Must contain the keys: 'chr', 'pos', 'effect_allele',
        'noneffect_allele', and 'weight'.
        Example for header=True: {'chr': 'CHR', 'pos': 'BP', ...}
        Example for header=False: {'chr': 1, 'pos': 2, ...}
    delimiter : str, default ','
        A field delimiter.
    comment : str or list[str], default '#'
        A character, or list of characters, that denote comment lines
        to be ignored.
    keep_other_cols : bool, default False
        If True, all columns not specified in `column_map` are preserved.
    validate_alleles : bool, default False
        If True, validates that allele columns contain only ACGT characters.
    **kwargs : dict, optional
        Other keyword arguments to pass directly to `hail.import_table`, such
        as `missing` or `min_partitions`.

    Returns
    -------
    hail.Table
        A Hail Table with standardized columns ready for PRS calculation.

    Raises
    ------
    ValueError
        If `column_map` is missing required keys, if the input file is empty,
        or if duplicate variants are found in the weights file.
    TypeError
        If the value types in `column_map` do not match the `header`
        setting (e.g., strings for `header=True`, integers for `header=False`).
    FileNotFoundError
        If a local `file_path` is provided and the file does not exist.
    """
    timer = SimpleTimer()
    with timer:
        # Validate the cheap, purely local arguments first so that a bad
        # column_map is reported before any file is staged to GCS.
        required_keys = {
            'chr', 'pos', 'effect_allele', 'noneffect_allele', 'weight'
        }
        if not required_keys.issubset(column_map.keys()):
            missing = required_keys - set(column_map.keys())
            raise ValueError(f"column_map is missing required keys: {missing}")
        _validate_column_map_type(column_map, header)

        gcs_path = _stage_local_file_to_gcs(file_path, sub_dir='temp_prs_data')

        try:
            if hfs.stat(gcs_path).size == 0:
                raise ValueError(f"Input file '{file_path}' is empty.")
        except hl.utils.java.FatalError as e:
            # A path that is a directory cannot be size-checked this way;
            # that specific error is tolerated, anything else propagates.
            if 'Is a directory' not in str(e):
                raise

        parser_func = (
            _read_prs_weights_header if header else _read_prs_weights_noheader
        )
        result_table = parser_func(
            file_path=gcs_path,
            column_map=column_map,
            delimiter=delimiter,
            comment=comment,
            keep_other_cols=keep_other_cols,
            validate_alleles=validate_alleles,
            **kwargs
        )

    # NOTE(review): logged after the timer context exits; this assumes
    # SimpleTimer populates `duration` on exit -- confirm against the helper.
    logger.info(
        "Weights file reading complete. Total time: %.2f seconds.",
        timer.duration
    )
    return result_table
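# NOTE(review): stray "[docs]" marker removed here; it appears to be a documentation-site link artifact, not code.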
def read_prscs(
    file_path: str,
    **kwargs
) -> hl.Table:
    """
    Read a PRS-CS output file via `read_prs_weights`.

    The file is assumed to follow the standard PRS-CS output layout: no
    header, tab-separated, with these columns:

    1. Chromosome
    2. Variant ID
    3. Base Position
    4. Effect Allele (A1)
    5. Non-Effect Allele (A2)
    6. Posterior Effect Size (weight)

    Note: column 2 (Variant ID) is not part of the column map, so it is only
    kept when `keep_other_cols=True` is passed through `**kwargs`.

    Parameters
    ----------
    file_path : str
        A path to the PRS-CS output file.
    **kwargs
        Other optional arguments to pass to `read_prs_weights`, such as
        `keep_other_cols` or `validate_alleles`.

    Returns
    -------
    hail.Table
        A processed Hail Table of the PRS-CS weights.
    """
    logger.info("Reading PRS-CS file: %s", file_path)

    # Standard names paired with their 1-based column positions in the file;
    # position 2 (the variant ID) is deliberately left unmapped.
    standard_names = (
        'chr', 'pos', 'effect_allele', 'noneffect_allele', 'weight'
    )
    positions = (1, 3, 4, 5, 6)
    prscs_map = dict(zip(standard_names, positions))

    return read_prs_weights(
        file_path=file_path,
        header=False,
        column_map=prscs_map,
        delimiter='\t',
        **kwargs
    )