image_test.py
import argparse
import numpy as np
import sys
import os
import csv
from imagenet_test_base import TestKit
import torch


class TestTorch(TestKit):
    """PyTorch back end of the ``TestKit`` ImageNet test harness.

    Registers reference top-5 predictions for ``inception_v3`` in
    ``self.truth`` and instantiates the generated ``KitModel`` in eval mode.
    NOTE(review): ``self.args``, ``self.truth`` and ``self.MainModel`` are
    assumed to be set up by ``TestKit.__init__`` (command-line parsing and
    loading of the generated network module) -- confirm in imagenet_test_base.
    """

    def __init__(self):
        super(TestTorch, self).__init__()

        # Reference (class index, score) pairs for the converted inception_v3.
        self.truth['tensorflow']['inception_v3'] = [(22, 9.6691055), (24, 4.3524747), (25, 3.5957973), (132, 3.5657473), (23, 3.346283)]
        self.truth['keras']['inception_v3'] = [(21, 0.93430489), (23, 0.002883445), (131, 0.0014781791), (24, 0.0014518998), (22, 0.0014435351)]

        # Build the network from the weights file given by -w, then switch to
        # eval mode so layers such as dropout/batch-norm behave as at inference.
        self.model = self.MainModel.KitModel(self.args.w)
        self.model.eval()

    def preprocess(self, image_path):
        """Load *image_path* through the base class and stage it in ``self.data``.

        The base-class result is transposed with axes (2, 0, 1) and given a
        leading batch axis of size 1. Presumably this converts an (H, W, C)
        image into a (1, C, H, W) batch -- confirm the base-class layout.
        """
        x = super(TestTorch, self).preprocess(image_path)
        x = np.transpose(x, (2, 0, 1))
        # The transpose is a strided view; copy into a contiguous buffer
        # before handing it to torch.
        x = np.expand_dims(x, 0).copy()
        self.data = torch.from_numpy(x)

    def print_result(self, image_name, top1, top5):
        """Run the model on ``self.data`` and delegate scoring to the base class.

        Returns the base class ``print_result`` value, which ``inference``
        unpacks as the updated ``(top1, top5)`` running counts.
        """
        # Pure inference: no_grad avoids building an autograd graph.
        with torch.no_grad():
            predict = self.model(self.data)
        predict = predict.detach().numpy()
        return super(TestTorch, self).print_result(predict, image_name, top1, top5)

    def print_intermediate_result(self, layer_name, if_transpose=False):
        """Print the tensor stored on the model's ``test`` attribute.

        ``layer_name`` is not used by this implementation.
        NOTE(review): assumes the generated KitModel assigns an intermediate
        tensor to ``self.test`` during forward -- confirm.
        """
        intermediate_output = self.model.test.data.numpy()
        super(TestTorch, self).print_intermediate_result(intermediate_output, if_transpose)

    def inference(self, images, image_dir="../data/imagenet/small_imagenet/"):
        """Evaluate every image listed in *images* and print top-1/top-5 accuracy.

        Args:
            images: path of a text file with one image file name per line.
                Blank lines and surrounding whitespace are ignored.
            image_dir: directory the listed file names are resolved against.

        Raises:
            ValueError: if *images* contains no image names.
        """
        top1 = 0.0
        top5 = 0.0
        image_count = 0
        with open(images) as fp_images:
            for line in fp_images:
                image_name = line.strip()
                if not image_name:
                    # Skip blank lines (e.g. a trailing newline) instead of
                    # failing on an empty record.
                    continue
                image_count += 1
                self.preprocess(os.path.join(image_dir, image_name))
                top1, top5 = self.print_result(image_name, top1, top5)

        if image_count == 0:
            raise ValueError("no image names found in %s" % images)

        print("top1's accuracy : %f" % (top1 / image_count))
        print("top5's accuracy : %f" % (top5 / image_count))

        # Optional debugging hooks:
        # self.print_intermediate_result(None, False)
        # self.test_truth()

    def dump(self, path=None):
        """Save the whole model object with ``torch.save``.

        Args:
            path: destination file; defaults to the ``--dump`` argument.
        """
        if path is None:
            path = self.args.dump
        # Saving the full module pickles a reference to its class. Loading the
        # file therefore needs the generated network module to be importable.
        torch.save(self.model, path)
        print('PyTorch model file is saved as [{}], generated by [{}.py] and [{}].'.format(
            path, self.args.n, self.args.w))


if __name__ == '__main__':
    tester = TestTorch()
    if tester.args.dump:
        # --dump given: export the model instead of evaluating it.
        tester.dump()
    else:
        tester.inference(tester.args.image)
image_test_base.py:
完整代码请见上传的文件，下载地址：https://files.cnblogs.com/files/jzcbest1016/imagenet_test_base.py.tar.gz
运行 image_test.py 时，需要在终端传入以下命令行参数（其中 -w 为必填项），参数解析代码摘自 image_test_base.py：
parser = argparse.ArgumentParser()

# Model preprocessing type (see help text).
parser.add_argument('-p', '--preprocess', type=_text_type, help='Model Preprocess Type')

# Network structure file name; defaults to kit_imagenet.
parser.add_argument('-n', type=_text_type, default='kit_imagenet',
                    help='Network structure file name.')

# Source framework type; must be one of the frameworks registered in self.truth.
parser.add_argument('-s', type=_text_type, help='Source Framework Type',
                    choices=self.truth.keys())

# Network weights file name (required).
parser.add_argument('-w', type=_text_type, required=True,
                    help='Network weights file name')

# File listing the test images; defaults to ../data/file_list.txt.
parser.add_argument('--image', '-i',
                    type=_text_type, help='Test image path.',
                    default="../data/file_list.txt"
                    )

# Label file for the test set.
parser.add_argument('-l', '--label',
                    type=_text_type,
                    default='../data/val.txt',
                    help='Path of label.')

# Output path for the converted (dumped) model.
parser.add_argument('--dump',
                    type=_text_type,
                    default=None,
                    help='Target model path.')

# Output path for model detection results.
parser.add_argument('--detect',
                    type=_text_type,
                    default=None,
                    help='Model detection result path.')

# TensorFlow model dump tag.
parser.add_argument('--dump_tag',
                    type=_text_type,
                    default=None,
                    help='Tensorflow model dump type',
                    choices=['SERVING', 'TRAINING'])
PyTorch ImageNet 测试代码
原文地址:https://www.cnblogs.com/jzcbest1016/p/9780356.html