Skip to content

Commit

Permalink
Merge branch 'graphhdv2' of github.com:hyperdimensional-computing/tor…
Browse files Browse the repository at this point in the history
…chhd into graphhdv2
  • Loading branch information
pereverges committed Nov 14, 2023
2 parents 87368e9 + 63bafdd commit fd873b2
Showing 1 changed file with 39 additions and 8 deletions.
47 changes: 39 additions & 8 deletions GraphHD_v2/graphhd_list_basic_node_attr2.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,9 @@ def index_value(inner_tensor):
return torch.argmax(inner_tensor)

if len(x.x[0]) > 0:
indices_tensor = torch.stack([index_value(inner_tensor) for inner_tensor in x.x.unbind()])
indices_tensor = torch.stack(
[index_value(inner_tensor) for inner_tensor in x.x.unbind()]
)
node_attr = self.node_attr.weight[indices_tensor]
node_id_hvs = torchhd.bind(node_id_hvs, torchhd.permute(node_attr))

Expand All @@ -125,7 +127,9 @@ def index_value(inner_tensor):
i = x.edge_index[0][idx]
j = x.edge_index[1][idx]
if prev == i:
aux_hv = torchhd.bind(aux_hv, torchhd.bind(node_id_hvs[i], node_id_hvs[j]))
aux_hv = torchhd.bind(
aux_hv, torchhd.bind(node_id_hvs[i], node_id_hvs[j])
)
else:
prev = i
final_hv = torchhd.bundle(final_hv, aux_hv)
Expand Down Expand Up @@ -153,10 +157,11 @@ def index_value(inner_tensor):
model.add(samples_hv, samples.y)
# break


train_t = time.time() - train_t
accuracy = torchmetrics.Accuracy("multiclass", num_classes=graphs.num_classes)
f1 = torchmetrics.F1Score(num_classes=graphs.num_classes, average='macro', multiclass=True)
f1 = torchmetrics.F1Score(
num_classes=graphs.num_classes, average="macro", multiclass=True
)
# f1 = torchmetrics.F1Score("multiclass", num_classes=graphs.num_classes)

test_t = time.time()
Expand All @@ -182,10 +187,36 @@ def index_value(inner_tensor):

REPETITIONS = 1
RANDOMNESS = ["random"]
DATASET = [ 'MOLT-4','MOLT-4H','Mutagenicity','MUTAG',
'NCI1','NCI109','NCI-H23','NCI-H23H','OVCAR-8','OVCAR-8H','P388','P388H','PC-3','PC-3H','PTC_FM',
'PTC_FR','PTC_MM','PTC_MR','SF-295','SF-295H','SN12C','SN12CH',
'SW-620','SW-620H','UACC257','UACC257H','Yeast','YeastH']
DATASET = [
"MOLT-4",
"MOLT-4H",
"Mutagenicity",
"MUTAG",
"NCI1",
"NCI109",
"NCI-H23",
"NCI-H23H",
"OVCAR-8",
"OVCAR-8H",
"P388",
"P388H",
"PC-3",
"PC-3H",
"PTC_FM",
"PTC_FR",
"PTC_MM",
"PTC_MR",
"SF-295",
"SF-295H",
"SN12C",
"SN12CH",
"SW-620",
"SW-620H",
"UACC257",
"UACC257H",
"Yeast",
"YeastH",
]

# ,'BZR_MD','COX2','COX2_MD','DHFR','DHFR_MD','ER_MD', 'FRANKENSTEIN', 'NCI109','KKI','OHSU','Peking_1','PROTEINS','AIDS']
VSAS = ["FHRR"]
Expand Down

0 comments on commit fd873b2

Please sign in to comment.