###実現したいこと
今deeplearningを用いて画像処理の研究を行っています。
そこで、VNetというネットワークを用いてSegmentationを行っているのですが、
リサイズ、正規化後の(処理途中の)画像を取り出してみたいのですができません。
コードで言うと、
getNumpyData(self,dat,method)
の最後のほうのret[key]が取り出したい部分です。
よろしくお願いいたします。
###ソースコードの内容
このDataManager.pyでは
データの読み込み、画像のリサイズ、正規化
そしてテスト後の結果画像の書き出しを行っています。
最後に結果画像を書き出している部分があるのでそこと同じようにやればできるのではないかと考えたのですが
できませんでした。
###扱っている画像
3次元のCT画像
濃度分解能:12[bit]
###試したこと
writeResultFromNumpyLabelの中で
3次元のラベルを出力している部分があるので同じように書いて見ましたが
だめでした。
具体的には
DataManager.pyのreturn retの前に
writer = sitk.ImageFileWriter()
writer.SetFileName('test.raw')
writer.Execute(ret[key])
と書いたところ
NotImplementedError: wrong number or arguments for overloaded function 'ImageFileWriter_Execute'.
Possible C/C++ prototypws are:
itk::simple::ImageFileWriter::Execute(itk::simple::Image const &)
itk::simple::IMageFileWriter::Execute(itk::simple::Image const &,std::string const &,bool)
というエラーが出ました。
Python
1#DataManager.py 2 3import numpy as np 4import SimpleITK as sitk 5from os import listdir 6from os.path import isfile, join, splitext 7 8class DataManager(object): 9 params=None 10 srcFolder=None 11 resultsDir=None 12 13 fileList=None 14 gtList=None 15 16 sitkImages=None 17 sitkGT=None 18 meanIntensityTrain = None 19 20 def __init__(self,srcFolder,resultsDir,parameters): 21 self.params=parameters 22 self.srcFolder=srcFolder 23 self.resultsDir=resultsDir 24 25 def createImageFileList(self): 26 self.fileList = [f for f in listdir(self.srcFolder) if isfile(join(self.srcFolder, f)) and 'segmentation' not in f and 'raw' not in f] 27 print 'FILE LIST: ' + str(self.fileList) 28 29 30 def createGTFileList(self): 31 self.gtList=list() 32 for f in self.fileList: 33 filename, ext = splitext(f) 34 self.gtList.append(join(filename + '_segmentation' + ext)) 35 36 37 def loadImages(self): 38 self.sitkImages=dict() 39 rescalFilt=sitk.RescaleIntensityImageFilter() 40 rescalFilt.SetOutputMaximum(1) 41 rescalFilt.SetOutputMinimum(0) 42 43 stats = sitk.StatisticsImageFilter() 44 m = 0. 45 for f in self.fileList: 46 self.sitkImages[f]=rescalFilt.Execute(sitk.Cast(sitk.ReadImage(join(self.srcFolder, f)),sitk.sitkFloat32)) 47 stats.Execute(self.sitkImages[f]) 48 m += stats.GetMean() 49 50 self.meanIntensityTrain=m/len(self.sitkImages) 51 52 53 def loadGT(self): 54 self.sitkGT=dict() 55 56 for f in self.gtList: 57 self.sitkGT[f]=sitk.Cast(sitk.ReadImage(join(self.srcFolder, f))>0.5,sitk.sitkFloat32) 58 59 60 61 def loadTrainingData(self): 62 self.createImageFileList() 63 self.createGTFileList() 64 self.loadImages() 65 self.loadGT() 66 67 68 def loadTestData(self): 69 self.createImageFileList() 70 self.loadImages() 71 72 def getNumpyImages(self): 73 dat = self.getNumpyData(self.sitkImages,sitk.sitkLinear) 74 return dat 75 76 77 def getNumpyGT(self): 78 dat = self.getNumpyData(self.sitkGT,sitk.sitkLinear) 79 80 for key in dat: 81 dat[key] = (dat[key]>0.5).astype(dtype=np.float32) 82 83 return dat 84 85 86 def getNumpyData(self,dat,method): 87 ret=dict() 88 for key in dat: 89 ret[key] = np.zeros([self.params['VolSize'][0], self.params['VolSize'][1], self.params['VolSize'][2]], dtype=np.float32) 90 91 img=dat[key] 92 93 #we rotate the image according to its transformation using the direction and according to the final spacing we want 94 factor = np.asarray(img.GetSpacing()) / [self.params['dstRes'][0], self.params['dstRes'][1], 95 self.params['dstRes'][2]] 96 97 factorSize = np.asarray(img.GetSize() * factor, dtype=float) 98 99 newSize = np.max([factorSize, self.params['VolSize']], axis=0) 100 101 newSize = newSize.astype(dtype=int) 102 103 T=sitk.AffineTransform(3) 104 T.SetMatrix(img.GetDirection()) 105 106 resampler = sitk.ResampleImageFilter() 107 resampler.SetReferenceImage(img) 108 resampler.SetOutputSpacing([self.params['dstRes'][0], self.params['dstRes'][1], self.params['dstRes'][2]]) 109 resampler.SetSize(newSize) 110 resampler.SetInterpolator(method) 111 if self.params['normDir']: 112 resampler.SetTransform(T.GetInverse()) 113 114 imgResampled = resampler.Execute(img) 115 116 117 imgCentroid = np.asarray(newSize, dtype=float) / 2.0 118 119 imgStartPx = (imgCentroid - self.params['VolSize'] / 2.0).astype(dtype=int) 120 121 regionExtractor = sitk.RegionOfInterestImageFilter() 122 regionExtractor.SetSize(list(self.params['VolSize'].astype(dtype=int))) 123 regionExtractor.SetIndex(list(imgStartPx)) 124 125 imgResampledCropped = regionExtractor.Execute(imgResampled) 126 127 ret[key] = np.transpose(sitk.GetArrayFromImage(imgResampledCropped).astype(dtype=float), [2, 1, 0]) 128 129 return ret 130 131 132 def writeResultsFromNumpyLabel(self,result,key): 133 img = self.sitkImages[key] 134 135 toWrite=sitk.Image(img.GetSize()[0],img.GetSize()[1],img.GetSize()[2],sitk.sitkFloat32) 136 137 factor = np.asarray(img.GetSpacing()) / [self.params['dstRes'][0], self.params['dstRes'][1], 138 self.params['dstRes'][2]] 139 140 factorSize = np.asarray(img.GetSize() * factor, dtype=float) 141 142 newSize = np.max([factorSize, self.params['VolSize']], axis=0) 143 144 newSize = newSize.astype(dtype=int) 145 146 T = sitk.AffineTransform(3) 147 T.SetMatrix(img.GetDirection()) 148 149 resampler = sitk.ResampleImageFilter() 150 resampler.SetReferenceImage(img) 151 resampler.SetOutputSpacing([self.params['dstRes'][0], self.params['dstRes'][1], self.params['dstRes'][2]]) 152 resampler.SetSize(newSize) 153 resampler.SetInterpolator(sitk.sitkNearestNeighbor) 154 155 if self.params['normDir']: 156 resampler.SetTransform(T.GetInverse()) 157 158 toWrite = resampler.Execute(toWrite) 159 160 imgCentroid = np.asarray(newSize, dtype=float) / 2.0 161 162 imgStartPx = (imgCentroid - self.params['VolSize'] / 2.0).astype(dtype=int) 163 164 for dstX, srcX in zip(range(0, result.shape[0]), range(imgStartPx[0],int(imgStartPx[0]+self.params['VolSize'][0]))): 165 for dstY, srcY in zip(range(0, result.shape[1]), range(imgStartPx[1], int(imgStartPx[1]+self.params['VolSize'][1]))): 166 for dstZ, srcZ in zip(range(0, result.shape[2]), range(imgStartPx[2], int(imgStartPx[2]+self.params['VolSize'][2]))): 167 try: 168 toWrite.SetPixel(int(srcX),int(srcY),int(srcZ),float(result[dstX,dstY,dstZ])) 169 except: 170 pass 171 172 173 resampler.SetOutputSpacing([img.GetSpacing()[0], img.GetSpacing()[1], img.GetSpacing()[2]]) 174 resampler.SetSize(img.GetSize()) 175 176 if self.params['normDir']: 177 resampler.SetTransform(T) 178 179 toWrite = resampler.Execute(toWrite) 180 181 thfilter=sitk.BinaryThresholdImageFilter() 182 thfilter.SetInsideValue(1) 183 thfilter.SetOutsideValue(0) 184 thfilter.SetLowerThreshold(0.5) 185 toWrite = thfilter.Execute(toWrite) 186 187 #connected component analysis (better safe than sorry) 188 189 cc = sitk.ConnectedComponentImageFilter() 190 toWritecc = cc.Execute(sitk.Cast(toWrite,sitk.sitkUInt8)) 191 192 arrCC=np.transpose(sitk.GetArrayFromImage(toWritecc).astype(dtype=float), [2, 1, 0]) 193 194 lab=np.zeros(int(np.max(arrCC)+1),dtype=float) 195 196 for i in range(1,int(np.max(arrCC)+1)): 197 lab[i]=np.sum(arrCC==i) 198 199 activeLab=np.argmax(lab) 200 201 toWrite = (toWritecc==activeLab) 202 203 toWrite = sitk.Cast(toWrite,sitk.sitkUInt8) 204 205 writer = sitk.ImageFileWriter() 206 filename, ext = splitext(key) 207 #print join(self.resultsDir, filename + '_result' + ext) 208 writer.SetFileName(join(self.resultsDir, filename + '_result' + ext)) 209 writer.Execute(toWrite)