
Question about pretrained weight path #63

Open
longdaothanh opened this issue May 27, 2024 · 9 comments

@longdaothanh commented May 27, 2024

I can see that in the config file we need to set the path to the pre-trained weights, but I'm not sure where to get them from.

Model -> sen1floods11_config.py
Set path to pretrained backbone weights:
pretrained_weights_path = ""

Best

@franrubi

You can find the pre-trained weights file at this location: https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M-sen1floods11/tree/main. Is that what you are looking for?
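For reference, a minimal sketch of fetching that file programmatically with the huggingface_hub client (repo and file names taken from the links in this thread; the returned local path is what the config's pretrained_weights_path would point to):

```python
# Sketch: download the sen1floods11 checkpoint from the Hugging Face Hub.
# hf_hub_download caches the file locally and returns its path.
from huggingface_hub import hf_hub_download

pretrained_weights_path = hf_hub_download(
    repo_id="ibm-nasa-geospatial/Prithvi-100M-sen1floods11",
    filename="sen1floods11_Prithvi_100M.pth",
)
print(pretrained_weights_path)  # paste this path into the config
```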

@longdaothanh

> You can find the pre-trained weights file at this location: https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M-sen1floods11/tree/main. Is that what you are looking for?

Hi, do you mean this file: https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M-sen1floods11/blob/main/sen1floods11_Prithvi_100M.pth?

One other question: can I also use these weights (https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M/blob/main/Prithvi_100M.pt) for fine-tuning?

Thank you so much

@franrubi

Yes to the first question: the ".../sen1floods11_Prithvi_100M.pth" file holds the weights learned for the water-detection/flood downstream task.
And ".../Prithvi_100M.pt" contains the weights of the foundation model, on top of which you can fine-tune your own downstream task.

@longdaothanh

> Yes to the first question: the ".../sen1floods11_Prithvi_100M.pth" file holds the weights learned for the water-detection/flood downstream task. And ".../Prithvi_100M.pt" contains the weights of the foundation model, on top of which you can fine-tune your own downstream task.

Let's say I want to fine-tune for my own task. Do I use the ".../Prithvi_100M.pt" weights with my own dataset and follow the fine-tuning instructions from the README.md? I am new to this.

Thank you so much

@franrubi

From the paper https://arxiv.org/abs/2310.18660 we can read the following:
“We thoroughly examine and assess the effect of Prithvi’s pre-trained weights on downstream tasks. We compare learning curves between 1) fine-tuning the entire model, 2) fine-tuning solely the decoder for the downstream task, and 3) training the model without utilizing Prithvi’s pre-trained weights. Our experiments show that the pre-trained model accelerates the fine-tuning process compared to leveraging randomly initialized weights."
So option 2 is the best approach.
In the MMSegmentation framework config file, the “pretrained_weights_path” setting points to the path where Prithvi_100M.pt has to be placed.
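As a rough sketch of option 2 (decoder-only fine-tuning): the configs in this repo expose a frozen_backbone flag on the model (see the full config later in this thread); assuming it freezes the encoder as the name suggests, it would be set like this:

```python
# Sketch: decoder-only fine-tuning (option 2). frozen_backbone appears in
# this repo's configs; True should keep the pre-trained encoder fixed
# while the neck and segmentation heads are trained.
pretrained_weights_path = "/path/to/Prithvi_100M.pt"

model = dict(
    type="TemporalEncoderDecoder",
    frozen_backbone=True,  # freeze the Prithvi backbone
    backbone=dict(
        type="TemporalViTEncoder",
        pretrained=pretrained_weights_path,
        # ... remaining backbone, neck, and head settings as in the full config
    ),
)
```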

@longdaothanh

> From the paper https://arxiv.org/abs/2310.18660 we can read the following: “We thoroughly examine and assess the effect of Prithvi’s pre-trained weights on downstream tasks. We compare learning curves between 1) fine-tuning the entire model, 2) fine-tuning solely the decoder for the downstream task, and 3) training the model without utilizing Prithvi’s pre-trained weights. Our experiments show that the pre-trained model accelerates the fine-tuning process compared to leveraging randomly initialized weights." So option 2 is the best approach. In the MMSegmentation framework config file, the “pretrained_weights_path” setting points to the path where Prithvi_100M.pt has to be placed.

```
2024-06-20 19:17:37,262 - mmseg - INFO - Set random seed to 2098588967, deterministic: False
load from C:/Masterarbeit/burn_scars_Prithvi_100M.pth
load checkpoint from local path: C:/Masterarbeit/burn_scars_Prithvi_100M.pth
The model and loaded state dict do not match exactly

unexpected key in source state_dict: backbone.cls_token, backbone.pos_embed, backbone.patch_embed.proj.weight, backbone.patch_embed.proj.bias, backbone.blocks.0.norm1.weight, backbone.blocks.0.norm1.bias, backbone.blocks.0.attn.qkv.weight, backbone.blocks.0.attn.qkv.bias, backbone.blocks.0.attn.proj.weight, backbone.blocks.0.attn.proj.bias, backbone.blocks.0.norm2.weight, backbone.blocks.0.norm2.bias, backbone.blocks.0.mlp.fc1.weight, backbone.blocks.0.mlp.fc1.bias, backbone.blocks.0.mlp.fc2.weight, backbone.blocks.0.mlp.fc2.bias, backbone.blocks.1.norm1.weight, backbone.blocks.1.norm1.bias, backbone.blocks.1.attn.qkv.weight, backbone.blocks.1.attn.qkv.bias, backbone.blocks.1.attn.proj.weight, backbone.blocks.1.attn.proj.bias, backbone.blocks.1.norm2.weight, backbone.blocks.1.norm2.bias, backbone.blocks.1.mlp.fc1.weight, backbone.blocks.1.mlp.fc1.bias, backbone.blocks.1.mlp.fc2.weight, backbone.blocks.1.mlp.fc2.bias, backbone.blocks.2.norm1.weight, backbone.blocks.2.norm1.bias, backbone.blocks.2.attn.qkv.weight, backbone.blocks.2.attn.qkv.bias, backbone.blocks.2.attn.proj.weight, backbone.blocks.2.attn.proj.bias, backbone.blocks.2.norm2.weight, backbone.blocks.2.norm2.bias, backbone.blocks.2.mlp.fc1.weight, backbone.blocks.2.mlp.fc1.bias, backbone.blocks.2.mlp.fc2.weight, backbone.blocks.2.mlp.fc2.bias, backbone.blocks.3.norm1.weight, backbone.blocks.3.norm1.bias, backbone.blocks.3.attn.qkv.weight, backbone.blocks.3.attn.qkv.bias, backbone.blocks.3.attn.proj.weight, backbone.blocks.3.attn.proj.bias, backbone.blocks.3.norm2.weight, backbone.blocks.3.norm2.bias, backbone.blocks.3.mlp.fc1.weight, backbone.blocks.3.mlp.fc1.bias, backbone.blocks.3.mlp.fc2.weight, backbone.blocks.3.mlp.fc2.bias, backbone.blocks.4.norm1.weight, backbone.blocks.4.norm1.bias, backbone.blocks.4.attn.qkv.weight, backbone.blocks.4.attn.qkv.bias, backbone.blocks.4.attn.proj.weight, backbone.blocks.4.attn.proj.bias, backbone.blocks.4.norm2.weight, backbone.blocks.4.norm2.bias, backbone.blocks.4.mlp.fc1.weight, backbone.blocks.4.mlp.fc1.bias, backbone.blocks.4.mlp.fc2.weight, backbone.blocks.4.mlp.fc2.bias, backbone.blocks.5.norm1.weight, backbone.blocks.5.norm1.bias, backbone.blocks.5.attn.qkv.weight, backbone.blocks.5.attn.qkv.bias, backbone.blocks.5.attn.proj.weight, backbone.blocks.5.attn.proj.bias, backbone.blocks.5.norm2.weight, backbone.blocks.5.norm2.bias, backbone.blocks.5.mlp.fc1.weight, backbone.blocks.5.mlp.fc1.bias, backbone.blocks.5.mlp.fc2.weight, backbone.blocks.5.mlp.fc2.bias, backbone.blocks.6.norm1.weight, backbone.blocks.6.norm1.bias, backbone.blocks.6.attn.qkv.weight, backbone.blocks.6.attn.qkv.bias, backbone.blocks.6.attn.proj.weight, backbone.blocks.6.attn.proj.bias, backbone.blocks.6.norm2.weight, backbone.blocks.6.norm2.bias, backbone.blocks.6.mlp.fc1.weight, backbone.blocks.6.mlp.fc1.bias, backbone.blocks.6.mlp.fc2.weight, backbone.blocks.6.mlp.fc2.bias, backbone.blocks.7.norm1.weight, backbone.blocks.7.norm1.bias, backbone.blocks.7.attn.qkv.weight, backbone.blocks.7.attn.qkv.bias, backbone.blocks.7.attn.proj.weight, backbone.blocks.7.attn.proj.bias, backbone.blocks.7.norm2.weight, backbone.blocks.7.norm2.bias, backbone.blocks.7.mlp.fc1.weight, backbone.blocks.7.mlp.fc1.bias, backbone.blocks.7.mlp.fc2.weight, backbone.blocks.7.mlp.fc2.bias, backbone.blocks.8.norm1.weight, backbone.blocks.8.norm1.bias, backbone.blocks.8.attn.qkv.weight, backbone.blocks.8.attn.qkv.bias, backbone.blocks.8.attn.proj.weight, backbone.blocks.8.attn.proj.bias, backbone.blocks.8.norm2.weight, 
backbone.blocks.8.norm2.bias, backbone.blocks.8.mlp.fc1.weight, backbone.blocks.8.mlp.fc1.bias, backbone.blocks.8.mlp.fc2.weight, backbone.blocks.8.mlp.fc2.bias, backbone.blocks.9.norm1.weight, backbone.blocks.9.norm1.bias, backbone.blocks.9.attn.qkv.weight, backbone.blocks.9.attn.qkv.bias, backbone.blocks.9.attn.proj.weight, backbone.blocks.9.attn.proj.bias, backbone.blocks.9.norm2.weight, backbone.blocks.9.norm2.bias, backbone.blocks.9.mlp.fc1.weight, backbone.blocks.9.mlp.fc1.bias, backbone.blocks.9.mlp.fc2.weight, backbone.blocks.9.mlp.fc2.bias, backbone.blocks.10.norm1.weight, backbone.blocks.10.norm1.bias, backbone.blocks.10.attn.qkv.weight, backbone.blocks.10.attn.qkv.bias, backbone.blocks.10.attn.proj.weight, backbone.blocks.10.attn.proj.bias, backbone.blocks.10.norm2.weight, backbone.blocks.10.norm2.bias, backbone.blocks.10.mlp.fc1.weight, backbone.blocks.10.mlp.fc1.bias, backbone.blocks.10.mlp.fc2.weight, backbone.blocks.10.mlp.fc2.bias, backbone.blocks.11.norm1.weight, backbone.blocks.11.norm1.bias, backbone.blocks.11.attn.qkv.weight, backbone.blocks.11.attn.qkv.bias, backbone.blocks.11.attn.proj.weight, backbone.blocks.11.attn.proj.bias, backbone.blocks.11.norm2.weight, backbone.blocks.11.norm2.bias, backbone.blocks.11.mlp.fc1.weight, backbone.blocks.11.mlp.fc1.bias, backbone.blocks.11.mlp.fc2.weight, backbone.blocks.11.mlp.fc2.bias, backbone.norm.weight, backbone.norm.bias, neck.fpn1.0.weight, neck.fpn1.0.bias, neck.fpn1.1.ln.weight, neck.fpn1.1.ln.bias, neck.fpn1.3.weight, neck.fpn1.3.bias, neck.fpn2.0.weight, neck.fpn2.0.bias, neck.fpn2.1.ln.weight, neck.fpn2.1.ln.bias, neck.fpn2.3.weight, neck.fpn2.3.bias, decode_head.conv_seg.weight, decode_head.conv_seg.bias, decode_head.convs.0.conv.weight, decode_head.convs.0.bn.weight, decode_head.convs.0.bn.bias, decode_head.convs.0.bn.running_mean, decode_head.convs.0.bn.running_var, decode_head.convs.0.bn.num_batches_tracked, auxiliary_head.conv_seg.weight, auxiliary_head.conv_seg.bias, auxiliary_head.convs.0.conv.weight, auxiliary_head.convs.0.bn.weight, auxiliary_head.convs.0.bn.bias, auxiliary_head.convs.0.bn.running_mean, auxiliary_head.convs.0.bn.running_var, auxiliary_head.convs.0.bn.num_batches_tracked, auxiliary_head.convs.1.conv.weight, auxiliary_head.convs.1.bn.weight, auxiliary_head.convs.1.bn.bias, auxiliary_head.convs.1.bn.running_mean, auxiliary_head.convs.1.bn.running_var, auxiliary_head.convs.1.bn.num_batches_tracked

missing keys in source state_dict: cls_token, pos_embed, patch_embed.proj.weight, patch_embed.proj.bias, blocks.0.norm1.weight, blocks.0.norm1.bias, blocks.0.attn.qkv.weight, blocks.0.attn.qkv.bias, blocks.0.attn.proj.weight, blocks.0.attn.proj.bias, blocks.0.norm2.weight, blocks.0.norm2.bias, blocks.0.mlp.fc1.weight, blocks.0.mlp.fc1.bias, blocks.0.mlp.fc2.weight, blocks.0.mlp.fc2.bias, blocks.1.norm1.weight, blocks.1.norm1.bias, blocks.1.attn.qkv.weight, blocks.1.attn.qkv.bias, blocks.1.attn.proj.weight, blocks.1.attn.proj.bias, blocks.1.norm2.weight, blocks.1.norm2.bias, blocks.1.mlp.fc1.weight, blocks.1.mlp.fc1.bias, blocks.1.mlp.fc2.weight, blocks.1.mlp.fc2.bias, blocks.2.norm1.weight, blocks.2.norm1.bias, blocks.2.attn.qkv.weight, blocks.2.attn.qkv.bias, blocks.2.attn.proj.weight, blocks.2.attn.proj.bias, blocks.2.norm2.weight, blocks.2.norm2.bias, blocks.2.mlp.fc1.weight, blocks.2.mlp.fc1.bias, blocks.2.mlp.fc2.weight, blocks.2.mlp.fc2.bias, blocks.3.norm1.weight, blocks.3.norm1.bias, blocks.3.attn.qkv.weight, blocks.3.attn.qkv.bias, blocks.3.attn.proj.weight, blocks.3.attn.proj.bias, blocks.3.norm2.weight, blocks.3.norm2.bias, blocks.3.mlp.fc1.weight, blocks.3.mlp.fc1.bias, blocks.3.mlp.fc2.weight, blocks.3.mlp.fc2.bias, blocks.4.norm1.weight, blocks.4.norm1.bias, blocks.4.attn.qkv.weight, blocks.4.attn.qkv.bias, blocks.4.attn.proj.weight, blocks.4.attn.proj.bias, blocks.4.norm2.weight, blocks.4.norm2.bias, blocks.4.mlp.fc1.weight, blocks.4.mlp.fc1.bias, blocks.4.mlp.fc2.weight, blocks.4.mlp.fc2.bias, blocks.5.norm1.weight, blocks.5.norm1.bias, blocks.5.attn.qkv.weight, blocks.5.attn.qkv.bias, blocks.5.attn.proj.weight, blocks.5.attn.proj.bias, blocks.5.norm2.weight, blocks.5.norm2.bias, blocks.5.mlp.fc1.weight, blocks.5.mlp.fc1.bias, blocks.5.mlp.fc2.weight, blocks.5.mlp.fc2.bias, blocks.6.norm1.weight, blocks.6.norm1.bias, blocks.6.attn.qkv.weight, blocks.6.attn.qkv.bias, blocks.6.attn.proj.weight, blocks.6.attn.proj.bias, blocks.6.norm2.weight, blocks.6.norm2.bias, blocks.6.mlp.fc1.weight, blocks.6.mlp.fc1.bias, blocks.6.mlp.fc2.weight, blocks.6.mlp.fc2.bias, blocks.7.norm1.weight, blocks.7.norm1.bias, blocks.7.attn.qkv.weight, blocks.7.attn.qkv.bias, blocks.7.attn.proj.weight, blocks.7.attn.proj.bias, blocks.7.norm2.weight, blocks.7.norm2.bias, blocks.7.mlp.fc1.weight, blocks.7.mlp.fc1.bias, blocks.7.mlp.fc2.weight, blocks.7.mlp.fc2.bias, blocks.8.norm1.weight, blocks.8.norm1.bias, blocks.8.attn.qkv.weight, blocks.8.attn.qkv.bias, blocks.8.attn.proj.weight, blocks.8.attn.proj.bias, blocks.8.norm2.weight, blocks.8.norm2.bias, blocks.8.mlp.fc1.weight, blocks.8.mlp.fc1.bias, blocks.8.mlp.fc2.weight, blocks.8.mlp.fc2.bias, blocks.9.norm1.weight, blocks.9.norm1.bias, blocks.9.attn.qkv.weight, blocks.9.attn.qkv.bias, blocks.9.attn.proj.weight, blocks.9.attn.proj.bias, blocks.9.norm2.weight, blocks.9.norm2.bias, blocks.9.mlp.fc1.weight, blocks.9.mlp.fc1.bias, blocks.9.mlp.fc2.weight, blocks.9.mlp.fc2.bias, blocks.10.norm1.weight, blocks.10.norm1.bias, blocks.10.attn.qkv.weight, blocks.10.attn.qkv.bias, blocks.10.attn.proj.weight, blocks.10.attn.proj.bias, blocks.10.norm2.weight, blocks.10.norm2.bias, blocks.10.mlp.fc1.weight, blocks.10.mlp.fc1.bias, blocks.10.mlp.fc2.weight, blocks.10.mlp.fc2.bias, blocks.11.norm1.weight, blocks.11.norm1.bias, blocks.11.attn.qkv.weight, blocks.11.attn.qkv.bias, blocks.11.attn.proj.weight, blocks.11.attn.proj.bias, blocks.11.norm2.weight, blocks.11.norm2.bias, blocks.11.mlp.fc1.weight, blocks.11.mlp.fc1.bias, blocks.11.mlp.fc2.weight, 
blocks.11.mlp.fc2.bias, norm.weight, norm.bias
```
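The "backbone." prefix on every unexpected key suggests this checkpoint stores a full fine-tuned model (backbone, neck, and heads), while pretrained_weights_path expects backbone-only weights such as Prithvi_100M.pt. A hypothetical workaround sketch (it assumes the usual MMSegmentation "state_dict" nesting; the output file name is made up):

```python
import torch

# Keep only the backbone weights and strip the "backbone." prefix so the
# keys match what the TemporalViTEncoder expects to load.
ckpt = torch.load("burn_scars_Prithvi_100M.pth", map_location="cpu")
state = ckpt.get("state_dict", ckpt)
backbone_only = {
    k[len("backbone."):]: v
    for k, v in state.items()
    if k.startswith("backbone.")
}
torch.save(backbone_only, "burn_scars_backbone_only.pth")
# Then point pretrained_weights_path at the new file.
```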

Config file:

```python
import os

custom_imports = dict(imports=["geospatial_fm"])

# base options
dist_params = dict(backend="nccl")
log_level = "INFO"
load_from = None
resume_from = None
cudnn_benchmark = True

dataset_type = "GeospatialDataset"

# TO BE DEFINED BY USER: data directory
data_root = "C:/Masterarbeit/hls-foundation-os/hls_burn_scars"

num_frames = 1
img_size = 224
num_workers = 4
samples_per_gpu = 4

img_norm_cfg = dict(
    means=[
        0.033349706741586264,
        0.05701185520536176,
        0.05889748132001316,
        0.2323245113436119,
        0.1972854853760658,
        0.11944914225186566,
    ],
    stds=[
        0.02269135568823774,
        0.026807560223070237,
        0.04004109844362779,
        0.07791732423672691,
        0.08708738838140137,
        0.07241979477437814,
    ],
)  # change the mean and std of all the bands

bands = [0, 1, 2, 3, 4, 5]
tile_size = 224
orig_nsize = 512
crop_size = (tile_size, tile_size)
img_suffix = "_merged.tif"
seg_map_suffix = ".mask.tif"
ignore_index = -1
image_nodata = -9999
image_nodata_replace = 0
image_to_float32 = True

# model
# TO BE DEFINED BY USER: model path
pretrained_weights_path = "C:/Masterarbeit/burn_scars_Prithvi_100M.pth"
num_layers = 12
patch_size = 16
embed_dim = 768
num_heads = 12
tubelet_size = 1
output_embed_dim = num_frames * embed_dim
max_intervals = 10000
evaluation_interval = 1000

# TO BE DEFINED BY USER: model path
experiment = "TEST"
project_dir = "C:/Masterarbeit"
work_dir = os.path.join(project_dir, experiment)
save_path = work_dir

train_pipeline = [
    dict(type="LoadGeospatialImageFromFile", to_float32=image_to_float32),
    dict(type="LoadGeospatialAnnotations", reduce_zero_label=False),
    dict(type="BandsExtract", bands=bands),
    dict(type="RandomFlip", prob=0.5),
    dict(type="ToTensor", keys=["img", "gt_semantic_seg"]),
    # to channels first
    dict(type="TorchPermute", keys=["img"], order=(2, 0, 1)),
    dict(type="TorchNormalize", **img_norm_cfg),
    dict(type="TorchRandomCrop", crop_size=(tile_size, tile_size)),
    dict(
        type="Reshape",
        keys=["img"],
        new_shape=(len(bands), num_frames, tile_size, tile_size),
    ),
    dict(type="Reshape", keys=["gt_semantic_seg"], new_shape=(1, tile_size, tile_size)),
    dict(type="CastTensor", keys=["gt_semantic_seg"], new_type="torch.LongTensor"),
    dict(type="Collect", keys=["img", "gt_semantic_seg"]),
]
test_pipeline = [
    dict(type="LoadGeospatialImageFromFile", to_float32=image_to_float32),
    dict(type="BandsExtract", bands=bands),
    dict(type="ToTensor", keys=["img"]),
    # to channels first
    dict(type="TorchPermute", keys=["img"], order=(2, 0, 1)),
    dict(type="TorchNormalize", **img_norm_cfg),
    dict(
        type="Reshape",
        keys=["img"],
        new_shape=(len(bands), num_frames, -1, -1),
        look_up=dict({"2": 1, "3": 2}),
    ),
    dict(type="CastTensor", keys=["img"], new_type="torch.FloatTensor"),
    dict(
        type="CollectTestList",
        keys=["img"],
        meta_keys=[
            "img_info",
            "seg_fields",
            "img_prefix",
            "seg_prefix",
            "filename",
            "ori_filename",
            "img",
            "img_shape",
            "ori_shape",
            "pad_shape",
            "scale_factor",
            "img_norm_cfg",
        ],
    ),
]

CLASSES = ("Unburnt land", "Burn scar")

data = dict(
    samples_per_gpu=samples_per_gpu,
    workers_per_gpu=num_workers,
    train=dict(
        type=dataset_type,
        CLASSES=CLASSES,
        data_root=data_root,
        img_dir="training",
        ann_dir="training",
        img_suffix=img_suffix,
        seg_map_suffix=seg_map_suffix,
        pipeline=train_pipeline,
        ignore_index=-1,
    ),
    val=dict(
        type=dataset_type,
        CLASSES=CLASSES,
        data_root=data_root,
        img_dir="validation",
        ann_dir="validation",
        img_suffix=img_suffix,
        seg_map_suffix=seg_map_suffix,
        pipeline=test_pipeline,
        ignore_index=-1,
    ),
    test=dict(
        type=dataset_type,
        CLASSES=CLASSES,
        data_root=data_root,
        img_dir="validation",
        ann_dir="validation",
        img_suffix=img_suffix,
        seg_map_suffix=seg_map_suffix,
        pipeline=test_pipeline,
        ignore_index=-1,
    ),
)

optimizer = dict(type="Adam", lr=1.3e-05, betas=(0.9, 0.999))
optimizer_config = dict(grad_clip=None)
lr_config = dict(
    policy="poly",
    warmup="linear",
    warmup_iters=1500,
    warmup_ratio=1e-06,
    power=1.0,
    min_lr=0.0,
    by_epoch=False,
)
log_config = dict(
    interval=20,
    hooks=[
        dict(type="TextLoggerHook", by_epoch=False),
        dict(type="TensorboardLoggerHook", by_epoch=False),
    ],
)
checkpoint_config = dict(by_epoch=True, interval=10, out_dir=save_path)
evaluation = dict(
    interval=evaluation_interval,
    metric="mIoU",
    pre_eval=True,
    save_best="mIoU",
    by_epoch=False,
)

loss_func = dict(type="DiceLoss", use_sigmoid=False, loss_weight=1, ignore_index=-1)

runner = dict(type="IterBasedRunner", max_iters=max_intervals)
workflow = [("train", 1)]
norm_cfg = dict(type="BN", requires_grad=True)
model = dict(
    type="TemporalEncoderDecoder",
    frozen_backbone=False,
    backbone=dict(
        type="TemporalViTEncoder",
        pretrained=pretrained_weights_path,
        img_size=img_size,
        patch_size=patch_size,
        num_frames=num_frames,
        tubelet_size=tubelet_size,
        in_chans=len(bands),
        embed_dim=embed_dim,
        depth=12,
        num_heads=num_heads,
        mlp_ratio=4.0,
        norm_pix_loss=False,
    ),
    neck=dict(
        type="ConvTransformerTokensToEmbeddingNeck",
        embed_dim=embed_dim * num_frames,
        output_embed_dim=output_embed_dim,
        drop_cls_token=True,
        Hp=14,
        Wp=14,
    ),
    decode_head=dict(
        num_classes=len(CLASSES),
        in_channels=output_embed_dim,
        type="FCNHead",
        in_index=-1,
        channels=256,
        num_convs=1,
        concat_input=False,
        dropout_ratio=0.1,
        norm_cfg=dict(type="BN", requires_grad=True),
        align_corners=False,
        loss_decode=loss_func,
    ),
    auxiliary_head=dict(
        num_classes=len(CLASSES),
        in_channels=output_embed_dim,
        type="FCNHead",
        in_index=-1,
        channels=256,
        num_convs=2,
        concat_input=False,
        dropout_ratio=0.1,
        norm_cfg=dict(type="BN", requires_grad=True),
        align_corners=False,
        loss_decode=loss_func,
    ),
    train_cfg=dict(),
    test_cfg=dict(
        mode="slide",
        stride=(int(tile_size / 2), int(tile_size / 2)),
        crop_size=(tile_size, tile_size),
    ),
)
auto_resume = False
```

I followed the instructions in the README and got this error (I also tried with Prithvi_100M.pt, with the same result). What did I do wrong? Thanks

@franrubi

Hi, are you using the 'burn_scars_Prithvi_100M.pth' file (1.2 GB), available from https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M-burn-scar/tree/main?

@longdaothanh

> Hi, are you using the 'burn_scars_Prithvi_100M.pth' file (1.2 GB), available from https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M-burn-scar/tree/main?

Yes, I can still train, but I just found it interesting.

@longdaothanh

> Hi, are you using the 'burn_scars_Prithvi_100M.pth' file (1.2 GB), available from https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M-burn-scar/tree/main?

Hi, I have one question: do you have any ideas on how to compute feature importance and permutation feature importance?
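This went unanswered in the thread, but as a generic sketch of permutation feature importance over the input bands (all names here are placeholders, not part of this repo's API): shuffle one band across the batch so it no longer aligns with the labels, re-score, and compare against the baseline.

```python
import torch

@torch.no_grad()
def band_permutation_importance(model, loader, score_fn, num_bands=6):
    """Score drop when one input band is shuffled across the batch.

    model, loader, and score_fn are placeholders: loader yields
    (img, mask) pairs with img shaped (B, C, T, H, W), and score_fn
    returns a scalar metric (e.g. IoU) for a batch of predictions.
    """
    def run(band=None):
        scores = []
        for img, mask in loader:
            if band is not None:
                img = img.clone()
                perm = torch.randperm(img.shape[0])
                img[:, band] = img[perm, band]  # decouple band from labels
            scores.append(score_fn(model(img), mask))
        return sum(scores) / len(scores)

    baseline = run()
    # A larger drop from baseline means a more important band.
    return {b: baseline - run(b) for b in range(num_bands)}
```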
