Skip to content

Commit

Permalink
Merge pull request #1186 from jinyangturbo/dev-postgresql
Browse files Browse the repository at this point in the history
Create cifar100.py
  • Loading branch information
chrishkchris authored Jul 19, 2024
2 parents c2ffece + a1416c2 commit d842c74
Showing 1 changed file with 81 additions and 0 deletions.
81 changes: 81 additions & 0 deletions examples/cnn_ms/data/cifar100.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#

try:
import pickle
except ImportError:
import cPickle as pickle

import numpy as np
import os
import sys


def load_dataset(filepath):
with open(filepath, 'rb') as fd:
try:
cifar100 = pickle.load(fd, encoding='latin1')
except TypeError:
cifar100 = pickle.load(fd)
image = cifar100['data'].astype(dtype=np.uint8)
image = image.reshape((-1, 3, 32, 32))
label = np.asarray(cifar100['fine_labels'], dtype=np.uint8)
label = label.reshape(label.size, 1)
return image, label


def load_train_data(dir_path='/tmp/cifar-100-python'):
images, labels = load_dataset(check_dataset_exist(dir_path + "/train"))
return np.array(images, dtype=np.float32), np.array(labels, dtype=np.int32)


def load_test_data(dir_path='/tmp/cifar-100-python'):
images, labels = load_dataset(check_dataset_exist(dir_path + "/test"))
return np.array(images, dtype=np.float32), np.array(labels, dtype=np.int32)


def check_dataset_exist(dirpath):
if not os.path.exists(dirpath):
print(
'Please download the cifar100 dataset using python data/download_cifar100.py'
)
sys.exit(0)
return dirpath


def normalize(train_x, val_x):
mean = [0.4914, 0.4822, 0.4465]
std = [0.2023, 0.1994, 0.2010]
train_x /= 255
val_x /= 255
for ch in range(0, 2):
train_x[:, ch, :, :] -= mean[ch]
train_x[:, ch, :, :] /= std[ch]
val_x[:, ch, :, :] -= mean[ch]
val_x[:, ch, :, :] /= std[ch]
return train_x, val_x


def load():
train_x, train_y = load_train_data()
val_x, val_y = load_test_data()
train_x, val_x = normalize(train_x, val_x)
train_y = train_y.flatten()
val_y = val_y.flatten()
return train_x, train_y, val_x, val_y

0 comments on commit d842c74

Please sign in to comment.