Skip to content

Commit

Permalink
Merge pull request #44 from johli/revision-upd-4
Browse files Browse the repository at this point in the history
Revision update
  • Loading branch information
johli authored Oct 8, 2024
2 parents 16815fa + 195a66d commit 28ca864
Show file tree
Hide file tree
Showing 8 changed files with 323 additions and 237 deletions.
16 changes: 16 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,22 @@ Documentation page: https://calico.github.io/baskerville/index.html
`cd baskerville`
`pip install .`

To set up the required environment variables:
`cd baskerville`
`conda activate <conda_env>`
`./env_vars.sh`

*Note:* Change the two lines of code at the top of './env_vars.sh' to the correct local paths.

Alternatively, the environment variables can be set manually:
```sh
export BASKERVILLE_DIR=/home/<user_path>/baskerville
export PATH=$BASKERVILLE_DIR/src/baskerville/scripts:$PATH
export PYTHONPATH=$BASKERVILLE_DIR/src/baskerville/scripts:$PYTHONPATH

export BASKERVILLE_CONDA=/home/<user>/anaconda3/etc/profile.d/conda.sh
```

---

#### Contacts
Expand Down
33 changes: 33 additions & 0 deletions env_vars.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#!/bin/bash

# set these variables before running the script
LOCAL_BASKERVILLE_PATH="/home/jlinder/baskerville"
LOCAL_CONDA_PATH="/home/jlinder/anaconda3/etc/profile.d/conda.sh"

# create env_vars sh scripts in local conda env
mkdir -p "$CONDA_PREFIX/etc/conda/activate.d"
mkdir -p "$CONDA_PREFIX/etc/conda/deactivate.d"

file_vars_act="$CONDA_PREFIX/etc/conda/activate.d/env_vars.sh"
if ! [ -e $file_vars_act ]; then
echo '#!/bin/sh' > $file_vars_act
fi

file_vars_deact="$CONDA_PREFIX/etc/conda/deactivate.d/env_vars.sh"
if ! [ -e $file_vars_deact ]; then
echo '#!/bin/sh' > $file_vars_deact
fi

# append env variable exports to /activate.d/env_vars.sh
echo "export BASKERVILLE_DIR=$LOCAL_BASKERVILLE_PATH" >> $file_vars_act
echo 'export PATH=$BASKERVILLE_DIR/src/baskerville/scripts:$PATH' >> $file_vars_act
echo 'export PYTHONPATH=$BASKERVILLE_DIR/src/baskerville/scripts:$PYTHONPATH' >> $file_vars_act

echo "export BASKERVILLE_CONDA=$LOCAL_CONDA_PATH" >> $file_vars_act

# append env variable unsets to /deactivate.d/env_vars.sh
echo 'unset BASKERVILLE_DIR' >> $file_vars_deact
echo 'unset BASKERVILLE_CONDA' >> $file_vars_deact

# finally activate env variables
source $file_vars_act
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ dependencies = [
"tabulate~=0.8.10",
"tensorflow~=2.15.0",
"tqdm~=4.65.0",
"pyfaidx~=0.7.1",
]

[project.optional-dependencies]
Expand Down
53 changes: 50 additions & 3 deletions src/baskerville/gene.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,11 @@ class Gene:
"""Class for managing genes in an isoform-agnostic way, taking
the union of exons across isoforms."""

def __init__(self, chrom, strand, kv):
def __init__(self, chrom, strand, kv, name=None):
self.chrom = chrom
self.strand = strand
self.kv = kv
self.name = name
self.exons = IntervalTree()

def add_exon(self, start, end):
Expand All @@ -77,10 +78,53 @@ def span(self):
exon_starts = [exon.begin for exon in self.exons]
exon_ends = [exon.end for exon in self.exons]
return min(exon_starts), max(exon_ends)

def output_slice_old(self, seq_start, seq_len, model_stride, span=False):
gene_slice = []

if span:
gene_start, gene_end = self.span()

# clip left boundaries
gene_seq_start = max(0, gene_start - seq_start)
gene_seq_end = max(0, gene_end - seq_start)

# requires >50% overlap
slice_start = int(np.round(gene_seq_start / model_stride))
slice_end = int(np.round(gene_seq_end / model_stride))

# clip right boundaries
slice_max = int(seq_len/model_stride)
slice_start = min(slice_start, slice_max)
slice_end = min(slice_end, slice_max)

gene_slice = range(slice_start, slice_end)

else:
for exon in self.get_exons():
# clip left boundaries
exon_seq_start = max(0, exon.begin - seq_start)
exon_seq_end = max(0, exon.end - seq_start)

# requires >50% overlap
slice_start = int(np.round(exon_seq_start / model_stride))
slice_end = int(np.round(exon_seq_end / model_stride))

# clip right boundaries
slice_max = int(seq_len/model_stride)
slice_start = min(slice_start, slice_max)
slice_end = min(slice_end, slice_max)

gene_slice.extend(range(slice_start, slice_end))

return np.array(gene_slice)

def output_slice(
self, seq_start, seq_len, model_stride, span=False, majority_overlap=False
self, seq_start, seq_len, model_stride, span=False, majority_overlap=False, old_version=False
):
if old_version :
return self.output_slice_old(seq_start, seq_len, model_stride, span=span)

gene_slice = []

def clip_boundaries(slice_start, slice_end):
Expand Down Expand Up @@ -162,10 +206,13 @@ def read_gtf(self, gtf_file):
strand = a[6]
kv = gtf_kv(a[8])
gene_id = kv["gene_id"]
gene_name = None
if 'gene_name' in kv:
gene_name = kv['gene_name']

# initialize gene
if gene_id not in self.genes:
self.genes[gene_id] = Gene(chrom, strand, kv)
self.genes[gene_id] = Gene(chrom, strand, kv, gene_name)

# add exon
self.genes[gene_id].add_exon(start - 1, end)
Expand Down
19 changes: 17 additions & 2 deletions src/baskerville/scripts/hound_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,9 @@ def main():
help="Generate cross fold split [Default: %default]",
)
parser.add_option(
"-g", dest="gaps_file", help="Genome assembly gaps BED [Default: %default]"
"-g",
dest="gaps_file",
help="Genome assembly gaps BED [Default: %default]"
)
parser.add_option(
"-i",
Expand Down Expand Up @@ -194,7 +196,11 @@ def main():
type="str",
help="Proportion of the data for testing [Default: %default]",
)
parser.add_option("-u", dest="umap_bed", help="Unmappable regions in BED format")
parser.add_option(
"-u",
dest="umap_bed",
help="Unmappable regions in BED format"
)
parser.add_option(
"--umap_t",
dest="umap_t",
Expand Down Expand Up @@ -230,6 +236,13 @@ def main():
type="str",
help="Proportion of the data for validation [Default: %default]",
)
parser.add_option(
"--transform_old",
dest="transform_old",
default=False,
action="store_true",
help="Apply old target transforms [Default: %default]",
)
(options, args) = parser.parse_args()

if len(args) != 2:
Expand Down Expand Up @@ -493,6 +506,8 @@ def main():
cmd += " -b %s" % options.blacklist_bed
if options.interp_nan:
cmd += " -i"
if options.transform_old:
cmd += " --transform_old"
cmd += " %s" % genome_cov_file
cmd += " %s" % seqs_bed_file
cmd += " %s" % seqs_cov_file
Expand Down
136 changes: 93 additions & 43 deletions src/baskerville/scripts/hound_data_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,13 @@ def main():
type="int",
help="Average pooling width [Default: %default]",
)
parser.add_option(
"--transform_old",
dest="transform_old",
default=False,
action="store_true",
help="Apply old target transforms [Default: %default]",
)
(options, args) = parser.parse_args()

if len(args) != 3:
Expand Down Expand Up @@ -180,49 +187,92 @@ def main():
# crop
if options.crop_bp > 0:
seq_cov_nt = seq_cov_nt[options.crop_bp : -options.crop_bp]

# scale
seq_cov_nt = options.scale * seq_cov_nt

# sum pool
seq_cov = seq_cov_nt.reshape(target_length, options.pool_width)
if options.sum_stat == "sum":
seq_cov = seq_cov.sum(axis=1, dtype="float32")
elif options.sum_stat == "sum_sqrt":
seq_cov = seq_cov.sum(axis=1, dtype="float32")
seq_cov = -1 + np.sqrt(1 + seq_cov)
elif options.sum_stat == "sum_exp75":
seq_cov = seq_cov.sum(axis=1, dtype="float32")
seq_cov = -1 + (1 + seq_cov) ** 0.75
elif options.sum_stat in ["mean", "avg"]:
seq_cov = seq_cov.mean(axis=1, dtype="float32")
elif options.sum_stat in ["mean_sqrt", "avg_sqrt"]:
seq_cov = seq_cov.mean(axis=1, dtype="float32")
seq_cov = -1 + np.sqrt(1 + seq_cov)
elif options.sum_stat == "median":
seq_cov = seq_cov.median(axis=1)
elif options.sum_stat == "max":
seq_cov = seq_cov.max(axis=1)
elif options.sum_stat == "peak":
seq_cov = seq_cov.mean(axis=1, dtype="float32")
seq_cov = np.clip(np.sqrt(seq_cov * 4), 0, 1)
else:
print(
'ERROR: Unrecognized summary statistic "%s".' % options.sum_stat,
file=sys.stderr,
)
exit(1)

# clip
if options.clip_soft is not None:
clip_mask = seq_cov > options.clip_soft
seq_cov[clip_mask] = (
options.clip_soft
- 1
+ np.sqrt(seq_cov[clip_mask] - options.clip_soft + 1)
)
if options.clip is not None:
seq_cov = np.clip(seq_cov, -options.clip, options.clip)

# apply original transform (from borzoi manuscript)
if options.transform_old:
# sum pool
seq_cov = seq_cov_nt.reshape(target_length, options.pool_width)
if options.sum_stat == 'sum':
seq_cov = seq_cov.sum(axis=1, dtype='float32')
elif options.sum_stat == 'sum_sqrt':
seq_cov = seq_cov.sum(axis=1, dtype='float32')
seq_cov = seq_cov**0.75
elif options.sum_stat in ['mean', 'avg']:
seq_cov = seq_cov.mean(axis=1, dtype='float32')
elif options.sum_stat in ['mean_sqrt', 'avg_sqrt']:
seq_cov = seq_cov.mean(axis=1, dtype='float32')
seq_cov = seq_cov**0.75
elif options.sum_stat == 'median':
seq_cov = seq_cov.median(axis=1)
elif options.sum_stat == 'max':
seq_cov = seq_cov.max(axis=1)
elif options.sum_stat == 'peak':
seq_cov = seq_cov.mean(axis=1, dtype='float32')
seq_cov = np.clip(np.sqrt(seq_cov*4), 0, 1)
else:
print(
'ERROR: Unrecognized summary statistic "%s".' % options.sum_stat,
file=sys.stderr
)
exit(1)

# clip
if options.clip_soft is not None:
clip_mask = seq_cov > options.clip_soft
seq_cov[clip_mask] = (
options.clip_soft
+ np.sqrt(seq_cov[clip_mask] - options.clip_soft)
)
if options.clip is not None:
seq_cov = np.clip(seq_cov, -options.clip, options.clip)

# scale
seq_cov = options.scale * seq_cov

else : #apply new (updated) transform

# scale
seq_cov_nt = options.scale * seq_cov_nt

# sum pool
seq_cov = seq_cov_nt.reshape(target_length, options.pool_width)
if options.sum_stat == "sum":
seq_cov = seq_cov.sum(axis=1, dtype="float32")
elif options.sum_stat == "sum_sqrt":
seq_cov = seq_cov.sum(axis=1, dtype="float32")
seq_cov = -1 + np.sqrt(1 + seq_cov)
elif options.sum_stat == "sum_exp75":
seq_cov = seq_cov.sum(axis=1, dtype="float32")
seq_cov = -1 + (1 + seq_cov) ** 0.75
elif options.sum_stat in ["mean", "avg"]:
seq_cov = seq_cov.mean(axis=1, dtype="float32")
elif options.sum_stat in ["mean_sqrt", "avg_sqrt"]:
seq_cov = seq_cov.mean(axis=1, dtype="float32")
seq_cov = -1 + np.sqrt(1 + seq_cov)
elif options.sum_stat == "median":
seq_cov = seq_cov.median(axis=1)
elif options.sum_stat == "max":
seq_cov = seq_cov.max(axis=1)
elif options.sum_stat == "peak":
seq_cov = seq_cov.mean(axis=1, dtype="float32")
seq_cov = np.clip(np.sqrt(seq_cov * 4), 0, 1)
else:
print(
'ERROR: Unrecognized summary statistic "%s".' % options.sum_stat,
file=sys.stderr,
)
exit(1)

# clip
if options.clip_soft is not None:
clip_mask = seq_cov > options.clip_soft
seq_cov[clip_mask] = (
options.clip_soft
- 1
+ np.sqrt(seq_cov[clip_mask] - options.clip_soft + 1)
)
if options.clip is not None:
seq_cov = np.clip(seq_cov, -options.clip, options.clip)

# clip float16 min/max
seq_cov = np.clip(seq_cov, np.finfo(np.float16).min, np.finfo(np.float16).max)
Expand Down
Loading

0 comments on commit 28ca864

Please sign in to comment.