Skip to content

Commit

Permalink
Create the folder for TED application in the healthcare model zoo
Browse files Browse the repository at this point in the history
  • Loading branch information
hugy718 committed Nov 16, 2024
1 parent 5482fe7 commit 426c692
Show file tree
Hide file tree
Showing 3 changed files with 307 additions and 0 deletions.
11 changes: 11 additions & 0 deletions examples/healthcare/application/TED_CT_Detection/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Convolutional Prototype Learning

We have successfully applied the idea of prototype loss to various medical image classification tasks to improve performance, for example, detecting thyroid eye disease (TED) from CT images. Here we provide the implementation of the convolutional prototype model in Singa. Due to data privacy, we are not able to release the CT image dataset used. The training script `./train.py` demonstrates how to apply this model to the CIFAR-10 dataset.

## run

From the Singa project root directory, run `python examples/healthcare/application/TED_CT_Detection/train.py`

## reference

[Robust Classification with Convolutional Prototype Learning](https://arxiv.org/abs/1805.03438)
115 changes: 115 additions & 0 deletions examples/healthcare/application/TED_CT_Detection/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from singa import layer
from singa import model
import singa.tensor as tensor
from singa import autograd
from singa.tensor import Tensor

class CPLayer(layer.Layer):
    """Prototype layer that turns features into distance-based logits.

    The layer holds a ``(feature_dim, prototype_count)`` prototype matrix.
    For every input row it computes the squared Euclidean distance to each
    prototype column via ``||f||^2 + ||p||^2 - 2 * f.p`` and returns
    ``-distance / temp`` as logits.

    Args:
        prototype_count: number of prototype columns (logit width).
        temp: temperature the negated distances are divided by.
    """

    def __init__(self, prototype_count=2, temp=10.0):
        super(CPLayer, self).__init__()
        self.prototype_count = prototype_count
        self.temp = temp

    def initialize(self, x):
        """Create the prototype matrix from the first input's feature width.

        Args:
            x: input tensor; ``x.shape[1]`` is taken as the feature dimension.
        """
        self.feature_dim = x.shape[1]
        # NOTE(review): the prototype is created with `tensor.random`, unlike
        # a parameter built with explicit gradient-storage flags -- confirm
        # that the optimizer actually updates it.
        self.prototype = tensor.random(
            (self.feature_dim, self.prototype_count), device=x.device)

    def forward(self, feat):
        """Return ``-squared_distance / temp`` between features and prototypes.

        Args:
            feat: feature tensor, used as ``(batch, feature_dim)`` by the
                matmul with the prototype matrix.

        Returns:
            Logits tensor with one column per prototype.
        """
        self.device_check(feat, self.prototype)
        self.dtype_check(feat, self.prototype)

        # ||f||^2 per sample: (batch, 1) repeated across the prototype axis.
        # The repeat count must be prototype_count (not feature_dim) so the
        # result lines up with the (batch, prototype_count) cross term.
        feat_sq = autograd.mul(feat, feat)
        feat_sq_sum = autograd.reduce_sum(feat_sq, axes=[1], keepdims=1)
        feat_sq_sum_tile = autograd.tile(feat_sq_sum,
                                         repeats=[1, self.prototype_count])

        # ||p||^2 per prototype: (1, prototype_count) repeated per sample.
        prototype_sq = autograd.mul(self.prototype, self.prototype)
        prototype_sq_sum = autograd.reduce_sum(prototype_sq,
                                               axes=[0],
                                               keepdims=1)
        prototype_sq_sum_tile = autograd.tile(prototype_sq_sum,
                                              repeats=feat.shape[0])

        # -2 * f.p, scaled with a constant tensor that takes no gradient.
        cross_term = autograd.matmul(feat, self.prototype)
        cross_term_scale = Tensor(shape=cross_term.shape,
                                  device=cross_term.device,
                                  requires_grad=False).set_value(-2)
        cross_term_scaled = autograd.mul(cross_term, cross_term_scale)

        dist = autograd.add(feat_sq_sum_tile, prototype_sq_sum_tile)
        dist = autograd.add(dist, cross_term_scaled)

        # Closer prototypes get larger logits; temp softens the distribution.
        logits_coeff = tensor.ones(
            (feat.shape[0], self.prototype.shape[1]),
            device=feat.device) * -1.0 / self.temp
        logits_coeff.requires_grad = False
        logits = autograd.mul(logits_coeff, dist)

        return logits

    def get_params(self):
        """Return the prototype tensor keyed by its tensor name."""
        return {self.prototype.name: self.prototype}

    def set_params(self, parameters):
        """Copy the prototype values from ``parameters`` (keyed by name)."""
        self.prototype.copy_from(parameters[self.prototype.name])


class CPL(model.Model):
    """Backbone network followed by a prototype layer (CPLayer).

    Args:
        backbone: model whose ``forward`` output feeds the prototype layer.
        prototype_count: number of prototypes given to the CPLayer.
        lamb: stored as ``self.lamb``; not read elsewhere in this class.
        temp: temperature given to the CPLayer.
        label: stored as ``self.prototype_label``; not read in this class.
        prototype_weight: stored as-is; not read in this class.
    """

    def __init__(self,
                 backbone: model.Model,
                 prototype_count=2,
                 lamb=0.5,
                 temp=10,
                 label=None,
                 prototype_weight=None):
        super(CPL, self).__init__()

        # Configuration values kept on the instance.
        self.prototype_label = label
        self.prototype_weight = prototype_weight
        self.lamb = lamb

        # Sub-layers, registered in the order backbone -> prototype -> loss.
        self.backbone = backbone
        self.cplayer = CPLayer(prototype_count=prototype_count, temp=temp)
        self.softmax_cross_entropy = layer.SoftMaxCrossEntropy()

    def set_optimizer(self, optimizer):
        """Attach the optimizer that ``train_one_batch`` applies to the loss."""
        self.optimizer = optimizer

    def forward(self, x):
        """Return the prototype-layer logits for input ``x``.

        The backbone is run through its ``forward`` method directly rather
        than by calling the backbone object.
        """
        return self.cplayer(self.backbone.forward(x))

    def train_one_batch(self, x, y, dist_option, spars):
        """Run one optimisation step and return ``(logits, loss)``.

        ``dist_option`` and ``spars`` are accepted for signature
        compatibility but are not used by this implementation.
        """
        logits = self.forward(x)
        batch_loss = self.softmax_cross_entropy(logits, y)
        self.optimizer(batch_loss)
        return logits, batch_loss


def create_model(backbone, prototype_count=2, lamb=0.5, temp=10.0):
    """Build a CPL model around ``backbone``.

    Args:
        backbone: feature-extractor model handed to CPL.
        prototype_count: number of prototypes in the prototype layer.
        lamb: forwarded to CPL as ``lamb``.
        temp: temperature forwarded to CPL.

    Returns:
        The constructed CPL instance.
    """
    return CPL(backbone,
               prototype_count=prototype_count,
               lamb=lamb,
               temp=temp)

# Names exported by a star-import of this module.
__all__ = ['CPL', 'create_model']
181 changes: 181 additions & 0 deletions examples/healthcare/application/TED_CT_Detection/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#

from singa import device
from singa import opt
from singa import tensor
import argparse
import numpy as np
import time
from PIL import Image

import sys
# Add the current working directory to the import search path so that the
# `examples.*` imports below resolve when the script is launched from the
# Singa project root; the resulting search path is then printed.
sys.path.append('.')
print(sys.path)

import examples.cnn.model.cnn as cnn
from examples.cnn.data import cifar10
import model as cpl

def accuracy(pred, target):
    """Count how many rows of ``pred`` have their arg-max equal to ``target``.

    Args:
        pred: 2-D array of network outputs, one row per sample.
        target: ground-truth integer labels, one per sample.

    Returns:
        The number of correct predictions (a count, not a ratio).
    """
    predicted_labels = np.argmax(pred, axis=1)
    matches = predicted_labels == target
    return np.asarray(matches, dtype="int").sum()


def resize_dataset(x, image_size):
    """Resize every channel plane of ``x`` to ``image_size`` x ``image_size``.

    Args:
        x: array indexed as ``[sample, channel, row, col]``.
        image_size: edge length of the square output planes.

    Returns:
        A float32 array of shape
        ``(x.shape[0], x.shape[1], image_size, image_size)`` whose planes are
        the bilinearly resized planes of ``x``.
    """
    target_size = (image_size, image_size)
    resized = np.zeros(shape=(x.shape[0], x.shape[1]) + target_size,
                       dtype=np.float32)
    for n, sample in enumerate(x):
        for d, plane in enumerate(sample):
            scaled = Image.fromarray(plane).resize(target_size,
                                                   Image.BILINEAR)
            resized[n, d, :, :] = np.array(scaled, dtype=np.float32)
    return resized


def run(local_rank,
        max_epoch,
        batch_size,
        sgd,
        graph,
        verbosity,
        dist_option='plain',
        spars=None):
    """Train and evaluate the CPL model on CIFAR-10.

    Args:
        local_rank: id of the CUDA GPU to create the device on.
        max_epoch: number of passes over the training set.
        batch_size: samples per batch; trailing partial batches are dropped.
        sgd: optimizer handed to the model.
        graph: passed to ``model.compile`` as ``use_graph``.
        verbosity: device logging verbosity.
        dist_option: forwarded to the model's training step.
        spars: forwarded to the model's training step.

    Raises:
        ValueError: if the backbone reports an input dimension other than
            2 or 4.
    """
    dev = device.create_cuda_gpu_on(local_rank)
    dev.SetRandSeed(0)
    np.random.seed(0)

    train_x, train_y, val_x, val_y = cifar10.load()

    num_channels = train_x.shape[1]
    data_size = np.prod(train_x.shape[1:train_x.ndim]).item()
    num_classes = (np.max(train_y) + 1).item()

    backbone = cnn.create_model(num_channels=num_channels,
                                num_classes=num_classes)
    # One prototype per class, so each logit column corresponds to a label.
    model = cpl.create_model(backbone,
                             prototype_count=num_classes,
                             lamb=0.5,
                             temp=10)

    if backbone.dimension == 4:
        tx = tensor.Tensor((batch_size, num_channels, backbone.input_size,
                            backbone.input_size), dev)
        train_x = resize_dataset(train_x, backbone.input_size)
        val_x = resize_dataset(val_x, backbone.input_size)
    elif backbone.dimension == 2:
        tx = tensor.Tensor((batch_size, data_size), dev)
        # np.reshape returns a new array, so the result must be kept.
        train_x = np.reshape(train_x, (train_x.shape[0], -1))
        val_x = np.reshape(val_x, (val_x.shape[0], -1))
    else:
        raise ValueError('unsupported backbone input dimension: %s' %
                         backbone.dimension)

    ty = tensor.Tensor((batch_size,), dev, tensor.int32)
    num_train_batch = train_x.shape[0] // batch_size
    num_val_batch = val_x.shape[0] // batch_size
    idx = np.arange(train_x.shape[0], dtype=np.int32)

    model.set_optimizer(sgd)
    model.compile([tx], is_train=True, use_graph=graph, sequential=True)
    dev.SetVerbosity(verbosity)

    for epoch in range(max_epoch):
        print(f'Epoch {epoch}')
        start_time = time.time()
        np.random.shuffle(idx)

        # Plain Python scalars keep the '%f' formatting below well-defined.
        train_correct = 0
        test_correct = 0
        train_loss = 0.0

        model.train()
        for b in range(num_train_batch):
            batch_idx = idx[b * batch_size:(b + 1) * batch_size]
            x = train_x[batch_idx]
            y = train_y[batch_idx]
            tx.copy_from_numpy(x)
            ty.copy_from_numpy(y)

            out, loss = model(tx, ty, dist_option, spars)
            train_correct += int(accuracy(tensor.to_numpy(out), y))
            train_loss += float(tensor.to_numpy(loss)[0])
        print('Training loss = %f, training accuracy = %f' %
              (train_loss, train_correct /
               max(num_train_batch * batch_size, 1)),
              flush=True)

        model.eval()
        for b in range(num_val_batch):
            x = val_x[b * batch_size:(b + 1) * batch_size]
            y = val_y[b * batch_size:(b + 1) * batch_size]

            tx.copy_from_numpy(x)
            ty.copy_from_numpy(y)

            # In evaluation mode only the input is passed: CPL.forward takes
            # a single argument and needs no labels or training options.
            out_test = model(tx)
            test_correct += int(accuracy(tensor.to_numpy(out_test), y))
        print('Evaluation accuracy = %f, Elapsed Time = %fs' %
              (test_correct / max(num_val_batch * batch_size, 1),
               time.time() - start_time),
              flush=True)


if __name__ == '__main__':
    # Command-line entry point: parse the training options, build the SGD
    # optimizer and start training.
    parser = argparse.ArgumentParser(description='Train a CPL model')
    parser.add_argument('-m',
                        '--max-epoch',
                        default=20,
                        type=int,
                        help='maximum epochs',
                        dest='max_epoch')
    parser.add_argument('-b',
                        '--batch-size',
                        default=64,
                        type=int,
                        help='batch size',
                        dest='batch_size')
    parser.add_argument('-l',
                        '--learning-rate',
                        default=0.005,
                        type=float,
                        help='initial learning rate',
                        dest='lr')
    parser.add_argument('-i',
                        '--device-id',
                        default=0,
                        type=int,
                        help='which GPU to use',
                        dest='device_id')
    # Graph mode is on unless the flag is given. The default is the boolean
    # True (not the string 'True') so args.graph is always a real bool.
    parser.add_argument('-g',
                        '--disable-graph',
                        default=True,
                        action='store_false',
                        help='disable graph',
                        dest='graph')
    parser.add_argument('-v',
                        '--log-verbosity',
                        default=0,
                        type=int,
                        help='logging verbosity',
                        dest='verbosity')
    args = parser.parse_args()
    print(args)

    sgd = opt.SGD(lr=args.lr, momentum=0.9, weight_decay=1e-5)
    run(args.device_id,
        args.max_epoch,
        args.batch_size,
        sgd,
        args.graph,
        args.verbosity)

0 comments on commit 426c692

Please sign in to comment.