only allow load trunk
hy395 committed Dec 26, 2024
1 parent 7f95a7a commit d6bd728
Showing 3 changed files with 19 additions and 23 deletions.
27 changes: 15 additions & 12 deletions docs/transfer/transfer.md
@@ -19,15 +19,16 @@ w5_folder=${data_path}/w5
 mkdir -p ${data_path}
 ```

-Download four replicate Borzoi pre-trained models (with identical train, validation and test splits (test = fold3, validation = fold4):
+Download four replicate Borzoi pre-trained model trunks:

 ```bash
-mkdir -p ${data_path}/weights
-wget --progress=bar:force "https://storage.googleapis.com/seqnn-share/borzoi/f0/model0_best.h5" -O ${data_path}/weights/borzoi_r0.h5
-wget --progress=bar:force "https://storage.googleapis.com/seqnn-share/borzoi/f1/model0_best.h5" -O ${data_path}/weights/borzoi_r1.h5
-wget --progress=bar:force "https://storage.googleapis.com/seqnn-share/borzoi/f2/model0_best.h5" -O ${data_path}/weights/borzoi_r2.h5
-wget --progress=bar:force "https://storage.googleapis.com/seqnn-share/borzoi/f3/model0_best.h5" -O ${data_path}/weights/borzoi_r3.h5
+gsutil cp -r gs://scbasset_tutorial_data/baskerville_transfer/pretrain_trunks/ ${data_path}
 ```
+Note:
+- Four replicate models have identical train, validation and test splits (test on fold3, validation on fold4, trained on rest). More details in the Borzoi manuscript.
+- Fold splits can be found in trainsplit/sequences.bed.
+- Model trunk refers to the model weights without the final dense layer (head).
+

 Download hg38 reference information, and train-validation-test-split information:

@@ -185,17 +186,19 @@ westminster_train_folds.py \
 ${data_path}/tfr
 ```

-Run hound_transfer.py on fold3 data for 4 replicate models:
+Run hound_transfer.py on training data in fold3 folder (identical to pre-train split) for four replicate models:

 ```bash
-hound_transfer.py -o train_rep0 --restore ${data_path}/weights/borzoi_r0.h5 params.json train/f3c0/data0
-hound_transfer.py -o train_rep1 --restore ${data_path}/weights/borzoi_r1.h5 params.json train/f3c0/data0
-hound_transfer.py -o train_rep2 --restore ${data_path}/weights/borzoi_r2.h5 params.json train/f3c0/data0
-hound_transfer.py -o train_rep3 --restore ${data_path}/weights/borzoi_r3.h5 params.json train/f3c0/data0
+hound_transfer.py -o train_rep0 --trunk --restore ${data_path}/pretrain_trunks/borzoi_r0.h5 params.json train/f3c0/data0
+hound_transfer.py -o train_rep1 --trunk --restore ${data_path}/pretrain_trunks/borzoi_r1.h5 params.json train/f3c0/data0
+hound_transfer.py -o train_rep2 --trunk --restore ${data_path}/pretrain_trunks/borzoi_r2.h5 params.json train/f3c0/data0
+hound_transfer.py -o train_rep3 --trunk --restore ${data_path}/pretrain_trunks/borzoi_r3.h5 params.json train/f3c0/data0
 ```

+Note: we recommend loading the model trunk only. While it is possible to load full Borzoi model and ignore last dense layer by model.load_weights(weight_file, skip_mismatch=True, by_name=True), Tensorflow requires loading layer weight by name in this way. If layer name don't match, weights of the layer will not be loaded and no warning message will be given.
+
 ### Step 7. Load models

 We apply weight merging for lora, ia3, and locon weights, and so there is no architecture changes once the model is trained. You can use the same params.json file, and load the train_rep0/model_best.mergeW.h5 weight file.

-For houlsby and houlsby_se, model architectures change due to the insertion of adapter modules. New architecture json file can be found in train_rep0/params.json.
+For houlsby and houlsby_se, model architectures change due to the insertion of adapter modules. New architecture json file is auto-generated in train_rep0/params.json;
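
Review note: the warning added to transfer.md above says that by-name loading skips any layer whose name does not match, without raising. The snippet below is a plain-Python toy of that failure mode, not the Keras implementation; `load_by_name` and the layer names are hypothetical and exist only for illustration.

```python
# Toy stand-in for by-name weight loading; this is NOT the Keras implementation.
def load_by_name(model_weights, saved_weights):
    """Copy saved weights into model_weights for matching layer names only.

    Returns the names of model layers that received no weights.
    """
    unmatched = []
    for name in model_weights:
        if name in saved_weights:
            model_weights[name] = saved_weights[name]
        else:
            # In the scenario the docs describe, this case passes silently.
            unmatched.append(name)
    return unmatched


# Hypothetical layer names: the checkpoint calls its first layer "conv1d",
# while the freshly built model named it "conv1d_1".
saved = {"conv1d": [0.5, -0.2], "dense": [1.0]}
model = {"conv1d_1": [0.0, 0.0], "dense_head": [0.0]}

unmatched = load_by_name(model, saved)
if unmatched:
    print("layers left at initial values:", unmatched)
```

Here every layer keeps its initial values, yet nothing fails. Loading a trunk-only checkpoint without by-name matching avoids relying on layer names lining up.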
8 changes: 3 additions & 5 deletions src/baskerville/scripts/hound_transfer.py
@@ -70,7 +70,7 @@ def main():
     parser.add_argument(
         "--restore",
         default=None,
-        help="pre-trained weights.h5 [Default: %(default)s]",
+        help="model trunk h5 file [Default: %(default)s]",
     )
     parser.add_argument(
         "--trunk",
@@ -180,10 +180,8 @@ def main():
         seqnn_model = seqnn.SeqNN(params_model)

         # restore
-        if args.trunk:
+        if args.restore:
             seqnn_model.restore(args.restore, trunk=args.trunk)
-        else:
-            seqnn_model.restore(args.restore, pretrain=True)

         # head params
         print(
@@ -364,7 +362,7 @@ def main():

             # restore
             if args.restore:
-                seqnn_model.restore(args.restore, args.trunk)
+                seqnn_model.restore(args.restore, trunk=args.trunk)

             # initialize trainer
             seqnn_trainer = trainer.Trainer(
7 changes: 1 addition & 6 deletions src/baskerville/seqnn.py
@@ -1020,15 +1020,10 @@ def predict_transform(

         return preds

-    def restore(self, model_file, head_i=0, trunk=False, pretrain=False):
+    def restore(self, model_file, head_i=0, trunk=False):
         """Restore weights from saved model."""
         if trunk:
             self.model_trunk.load_weights(model_file)
-        elif pretrain:
-            self.models[head_i].load_weights(
-                model_file, by_name=True, skip_mismatch=True
-            )
-            self.model = self.models[head_i]
         else:
             self.models[head_i].load_weights(model_file)
             self.model = self.models[head_i]
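
Review note: after this commit, `restore` has two paths, trunk-only or full model for one head. The sketch below mirrors that branching with stub objects so it runs without TensorFlow. `StubModel` and `RestoreSketch` are hypothetical names, and only the `restore` signature and branching are taken from the diff. Because `head_i` is the second parameter, a positional second argument binds to `head_i` rather than `trunk`, which is presumably why the updated call sites pass `trunk=` by keyword.

```python
class StubModel:
    """Stand-in for a Keras model; records the file it was asked to load."""

    def __init__(self):
        self.loaded_from = None

    def load_weights(self, model_file):
        self.loaded_from = model_file


class RestoreSketch:
    """Hypothetical container with only the attributes restore() touches."""

    def __init__(self, num_heads=1):
        self.model_trunk = StubModel()
        self.models = [StubModel() for _ in range(num_heads)]
        self.model = None

    def restore(self, model_file, head_i=0, trunk=False):
        """Same branching as the post-commit method in the diff."""
        if trunk:
            # Trunk-only checkpoint: the head keeps its fresh initialization.
            self.model_trunk.load_weights(model_file)
        else:
            # Full checkpoint for one head.
            self.models[head_i].load_weights(model_file)
            self.model = self.models[head_i]


sketch = RestoreSketch()
sketch.restore("borzoi_r0.h5", trunk=True)  # keyword, as in the updated call sites
print(sketch.model_trunk.loaded_from)       # borzoi_r0.h5
print(sketch.models[0].loaded_from)         # None
```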
