From 3150da8bf29ae9e05aa0932dede8e98a99f86f96 Mon Sep 17 00:00:00 2001 From: Bowenyin <52337138+Yinbowen-chn@users.noreply.github.com> Date: Fri, 9 Dec 2022 21:37:02 +0800 Subject: [PATCH] realeased_maincode --- dataset.py | 241 +++++++ evaltools/__pycache__/metrics.cpython-38.pyc | Bin 0 -> 14688 bytes evaltools/eval.py | 162 +++++ evaltools/metrics.py | 442 +++++++++++++ evaltools/valid_eval.sh | 8 + main.py | 184 ++++++ model/CamoFormer.py | 104 +++ model/__pycache__/CamoFormer.cpython-38.pyc | Bin 0 -> 3249 bytes .../__pycache__/decoder_p.cpython-38.pyc | Bin 0 -> 10868 bytes model/decoder/decoder_p.py | 296 +++++++++ .../__pycache__/pvtv2_encoder.cpython-38.pyc | Bin 0 -> 15744 bytes .../__pycache__/swin_encoder.cpython-38.pyc | Bin 0 -> 20350 bytes model/encoder/pvtv2_encoder.py | 445 +++++++++++++ model/encoder/swin_encoder.py | 608 ++++++++++++++++++ test_eval.sh | 8 + train.sh | 12 + 16 files changed, 2510 insertions(+) create mode 100644 dataset.py create mode 100644 evaltools/__pycache__/metrics.cpython-38.pyc create mode 100644 evaltools/eval.py create mode 100644 evaltools/metrics.py create mode 100644 evaltools/valid_eval.sh create mode 100644 main.py create mode 100644 model/CamoFormer.py create mode 100644 model/__pycache__/CamoFormer.cpython-38.pyc create mode 100644 model/decoder/__pycache__/decoder_p.cpython-38.pyc create mode 100644 model/decoder/decoder_p.py create mode 100644 model/encoder/__pycache__/pvtv2_encoder.cpython-38.pyc create mode 100644 model/encoder/__pycache__/swin_encoder.cpython-38.pyc create mode 100644 model/encoder/pvtv2_encoder.py create mode 100644 model/encoder/swin_encoder.py create mode 100644 test_eval.sh create mode 100644 train.sh diff --git a/dataset.py b/dataset.py new file mode 100644 index 0000000..d56ff04 --- /dev/null +++ b/dataset.py @@ -0,0 +1,241 @@ +#!/usr/bin/python3 +#coding=utf-8 + +import os +import cv2 +import torch +import numpy as np +from torch.utils.data import Dataset +from PIL import Image +import random + +########################### Data Augmentation ########################### +class Normalize(object): + def __init__(self, mean, std): + self.mean = mean + self.std = std + + def __call__(self, image, mask=None, body=None, detail=None): + image = (image - self.mean)/self.std + if mask is None: + return image + return image, mask/255 + +class RandomCrop(object): + def __call__(self, image, mask=None, body=None, detail=None): + H,W,_ = image.shape + randw = np.random.randint(W/8) + randh = np.random.randint(H/8) + offseth = 0 if randh == 0 else np.random.randint(randh) + offsetw = 0 if randw == 0 else np.random.randint(randw) + p0, p1, p2, p3 = offseth, H+offseth-randh, offsetw, W+offsetw-randw + if mask is None: + return image[p0:p1,p2:p3, :] + return image[p0:p1,p2:p3, :], mask[p0:p1,p2:p3] + +class RandomFlip(object): + def __call__(self, image, mask=None, body=None, detail=None): + if np.random.randint(2)==0: + if mask is None: + return image[:,::-1,:].copy() + return image[:,::-1,:].copy(), mask[:, ::-1].copy() + else: + if mask is None: + return image + return image, mask + +class Resize(object): + def __init__(self, H, W): + self.H = H + self.W = W + + def __call__(self, image, mask=None, body=None, detail=None): + image = cv2.resize(image, dsize=(self.W, self.H), interpolation=cv2.INTER_LINEAR) + if mask is None: + return image + mask = cv2.resize( mask, dsize=(self.W, self.H), interpolation=cv2.INTER_LINEAR) + body = cv2.resize( body, dsize=(self.W, self.H), interpolation=cv2.INTER_LINEAR) + detail= cv2.resize( 
detail, dsize=(self.W, self.H), interpolation=cv2.INTER_LINEAR) + return image, mask + +class RandomRotate(object): + def rotate(self, x, random_angle, mode='image'): + + if mode == 'image': + H, W, _ = x.shape + else: + H, W = x.shape + + random_angle %= 360 + image_change = cv2.getRotationMatrix2D((W/2, H/2), random_angle, 1) + image_rotated = cv2.warpAffine(x, image_change, (W, H)) + + angle_crop = random_angle % 180 + if random_angle > 90: + angle_crop = 180 - angle_crop + theta = angle_crop * np.pi / 180 + hw_ratio = float(H) / float(W) + tan_theta = np.tan(theta) + numerator = np.cos(theta) + np.sin(theta) * np.tan(theta) + r = hw_ratio if H > W else 1 / hw_ratio + denominator = r * tan_theta + 1 + crop_mult = numerator / denominator + + w_crop = int(crop_mult * W) + h_crop = int(crop_mult * H) + x0 = int((W - w_crop) / 2) + y0 = int((H - h_crop) / 2) + crop_image = lambda img, x0, y0, W, H: img[y0:y0+h_crop, x0:x0+w_crop ] + output = crop_image(image_rotated, x0, y0, w_crop, h_crop) + + return output + + def __call__(self, image, mask=None, body=None, detail=None): + + do_seed = np.random.randint(0,3) + if do_seed != 2: + if mask is None: + return image + return image, mask + + random_angle = np.random.randint(-10, 10) + image = self.rotate(image, random_angle, 'image') + + if mask is None: + return image + mask = self.rotate(mask, random_angle, 'mask') + + return image, mask + + +class ColorEnhance(object): + def __init__(self): + + #A:0.5~1.5, G: 5-15 + self.A = np.random.randint(7, 13, 1)[0]/10 + self.G = np.random.randint(7, 13, 1)[0] + + + def __call__(self, image, mask=None, body=None, detail=None): + + do_seed = np.random.randint(0,3) + if do_seed > 1:#1: # 1/3 + H, W, _ = image.shape + dark_matrix = np.zeros([H, W, _], image.dtype) + image = cv2.addWeighted(image, self.A, dark_matrix, 1-self.A, self.G) + else: + pass + + if mask is None: + return image + return image, mask + +class GaussNoise(object): + def __init__(self): + self.Mean = 0 + self.Var = 0.001 + + def __call__(self, image, mask=None, body=None, detail=None): + H, W, _ = image.shape + do_seed = np.random.randint(0,3) + + + if do_seed == 0: #1: # 1/3 + factor = np.random.randint(0,10) + noise = np.random.normal(self.Mean, self.Var ** 0.5, image.shape) * factor + noise = noise.astype(image.dtype) + image = cv2.add(image, noise) + else: + pass + + if mask is None: + return image + return image, mask + + + +class ToTensor(object): + def __call__(self, image, mask=None, body=None, detail=None): + image = torch.from_numpy(image) + image = image.permute(2, 0, 1) + if mask is None: + return image + mask = torch.from_numpy(mask) + return image, mask + + +########################### Config File ########################### +class Config(object): + def __init__(self, **kwargs): + self.kwargs = kwargs + self.mean = np.array([[[124.55, 118.90, 102.94]]]) + self.std = np.array([[[ 56.77, 55.97, 57.50]]]) + print('\nParameters...') + for k, v in self.kwargs.items(): + print('%-10s: %s'%(k, v)) + + def __getattr__(self, name): + if name in self.kwargs: + return self.kwargs[name] + else: + return None + + +########################### Dataset Class ########################### +class Data(Dataset): + def __init__(self, cfg, model_name): + self.cfg = cfg + self.model_name = model_name + self.normalize = Normalize(mean=cfg.mean, std=cfg.std) + self.randomcrop = RandomCrop() + self.randomflip = RandomFlip() + + self.resize = Resize(384, 384) + + self.randomrotate = RandomRotate() + self.colorenhance = ColorEnhance() + self.gaussnoise 
= GaussNoise() + self.totensor = ToTensor() + + self.samples=os.listdir(cfg.datapath+'/Image') + + def __getitem__(self, idx): + name = self.samples[idx] + try: + image = cv2.imread(self.cfg.datapath+'/Image/'+name.replace('.jpg','')+'.jpg') + except: + print(str(name)+' not found!') + + + if self.cfg.mode=='train': + try: + mask = cv2.imread(self.cfg.datapath+'/GT/' +name.replace('.jpg','')+'.png', 0).astype(np.float32) + except: + print(str(name)+' not found!') + + image, mask = self.normalize(image, mask) + image, mask = self.randomcrop(image, mask) + image, mask = self.randomflip(image, mask) + + return image, mask + else: + shape = image.shape[:2] + image = self.normalize(image) + image = self.resize(image) + image = self.totensor(image) + + return image, shape, name + + def __len__(self): + return len(self.samples) + + def collate(self, batch): + size = 384 + image, mask = [list(item) for item in zip(*batch)] + for i in range(len(batch)): + image[i] = cv2.resize(image[i], dsize=(size, size), interpolation=cv2.INTER_LINEAR) + mask[i] = cv2.resize(mask[i], dsize=(size, size), interpolation=cv2.INTER_LINEAR) + + image = torch.from_numpy(np.stack(image, axis=0)).permute(0,3,1,2) + mask = torch.from_numpy(np.stack(mask, axis=0)).unsqueeze(1) + return image, mask diff --git a/evaltools/__pycache__/metrics.cpython-38.pyc b/evaltools/__pycache__/metrics.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..883ca5365f52f326e812d1bab615fac1b4fcdbfd GIT binary patch literal 14688 zcmd^G+mjsES?}xIXLfdXv|7m*#>S4lA(3oJ35gMo>_|>h*c-=E5@%v(GCSS7JFB_* zbg!hH^guv1BoG5hAc0_l7AjN}C<1v1JitFd@dQ*sb@4zI4=Ghq`~g%d=J)$fPw&j^ z%5iw$g_-Kp=X~dM&$)b;-}jv}A1{{+20o9uUu|zaX&C>(!t}>Q;R3GU`$#;aVR)uz zbWPQ&TBjf`n{P9y7EJGn+4X+f@{=e}Wk887>)(NfGrF_h@E*onnM%}^(2cG{#{W(8Wopw;T=m+g5>Q}ZqLA{NN`nKP?*6HrmHwH=# zuKQknQw@6c?Jyh$PnP*e5q1#60CMU*Tzw>##>&}}<&|kaPx`CwAzilb5yUHNBZKyJC|NY_9+k>8e zdUs>DAMT#+?KPD@90aG&Ja*=6^MkG4;2Dg`SErTV^;?18-0bvQ-P8WHRyQ0By20t5 zp7~S5y|~=OP=_t$H@#NaD&ggU@SvH)w`$&Nss+&Qm`K*_IMZuk&2Gebbgqfl?fNQ< zn!C@=5^6)Rsh*1(X6h!#YOUYlB(7Pih+FLR{8k@@MHCbtkx}ZT!Uo5iTYlJVd97jC zxrTMyR(`NO=z29&2C_7!I);23*UX2@bOBeegTyx)CZxs!|Jt7Q24u!_!1s>t?qqmD zLaanWvJ)Cw<{%+?uK>BprPRDoi(U!Tf?xE?cpGJfvQ@Rr z{NSv7{^; zLT1#A?5pnBiJZtfA-ZYoMy_XHwbUuk9%uUJBRAxnJNqUyQ#L7CQp!<=cbwx!l!X-M zaB)o}f1G+O6J@~m&bmzLZBt9JXbo2+Gq!`>UTpP;ak)K!)@t?#{gJN*!gE-Hxfxta zXE2rZ#^MIlnA)Sl^4cBh_r@X=R}1_Z`YQ0-B&kO*>exn`*oFwqakF{|wWn|;6J3%i z&Y2lZc*U%nqxzwVp5uy{>88G-S{9-GPh>dRH7Nhk+`+Vs&Ct9BO+!7iZ;x$Ig^f=b zk?py!lkLIE6Y7F7Mb8Kwb+-iuyuN(+yobT_#Y_`IIii@7#AHe#u zAW}ZhN$|gsZAZbLJ9EXwDoRKa#J83X&oQlyf-wkfTzr z`U(mmgcTPRq#Zi@nzNuRL#Rm5+|Z%QB%gstEKnyknVA}UswO*A~y$1>Gh~4srvD@BN*L+|Ya~PWncWblPu!d@_ zAgixvwPUxXRBKNwS+{xRQ_o-0(>$wSW5?^XLm0RR_#>33uD@hOU>phyRB|g<8F2bKuizf3C;M~zNg1c6EE5*cEjCa*H`2x^-(4- zFgdUgk|PRID@ayx2|@!p%hk#&rIjV2-yQj^~5H1nYQ^{IyDKKljw7_B7KnpW+t^tsoS!w(2B`YfuNhmJhX{ zm-1kIGBnQ6CvhcUZdAhiQzW5{b(@B5@}aHeC7PG&frf4xO574C0r&v80M1U#nl}Hs zLyk84St@w+Of~Qp^)Qm#?wodkB7>PBBKN1(Mf%NUo8-YledYQDj zd437B&*BP*#SCbsI$*}A$R1DT;UX$#mX^z!cAB6l?d-HZa%uZsHwD~7NSbH|v6X&N z)YNlKWRA&@!s$T#BGzm>-%$#|7Op_14{|zMU?-yM$) z{CLO%*KsA|sSsCNqa4R`Afz@LOg{gwQJs1W)1ZEq$p@IsF$b;GFX4K7sHeb5re1;L zbfj8IS>`0(nSc;0w8aB7Lx5#?E>tP-*Dj$2^a@va&7IOK!W@TG%2A9!mGWg=TBTT3 zs1){ZN~NUMy)>^^M28%uTwcTr-;V8Wq1}K#PQ?q`ovr0yxkZO8LzBtuinUPU4qNhBnRdMtQVDkN!J@)m&<81Hi$<@D)_s4+(UoW?>*)$e5N z&uLWsinj{ocEUS}(xP_{mh-+icR7iU`~YLb=ifg7fZ(>e!|k^Si_lxwutM}5VmTA- z2H+X=0_CB>#HcVo&%Ip0Z{SLdO2!;5PWtf{x{XYlGU@@=Ad^l47aDEs&xI}6KZLqO zTP6xk#}R^_^4UGnY0jKZi!D^oi5+#=H`#FwGpbg}9F^0qAFak*v`@`;7X?8SbRG~@ 
zLA=sZ4|;GG*!6uAirDEt13H{P_1*~;v@=xCh~C*m9{W}e5l&Jhk}1wCX>dKA&_#s6 zfukw7gc3_gs3R=XIDqRCX9gQP@Yfk#-J;V@zzC@pG{AmKuroII6d}O;WNd==EnETl z!7vLzG4!=c)}zRe7Sd^<;t6^)4m-^hFELY88XOf|U(82@IY5e$p$vxz^TN48gao6v z-~#EH9dYE|G;9iQ)~8tf{JrNQ;+tL^=ep~TMcVtD>ioUhEDLv{$jyy!8C#iQ_i4rZ?hj(Any>nm1zi z$iZI4TmKqYVy~CrLL)G9AgGh+oEa=#qm&HZ)f!FtiGLV z*jwS_OwKSl%j8@mBkS^bazB&Y&nEYC^YYli&;K2+#HFr48ueEPFH6>xOlo_`dW4=K6*4U|QvjWft0)`>P* z$-$c%=h&JpLW_`;8)pb=AfCBzMwx^6O|@{O1tOhk3paG3=}eDu0ansE35&Ri)}v%d zNF>pJVD;@5f?xE|bRfW0h=L-l+`<;r7Gjl%hfSiO#GkRb9h=vumbWO8PA%q*EcV5k z9WQ8P+k-AM^Gg!q6@QN_S&{;x8(0>r0JTvtAH;V`o+d2UrjSlPHDppnS%W5zELjKy z-k}gPP?1?GlE`_Dsr{O@3Hcc#+Nf!evrIzEN5J`!JJYT(p-r_baxv!KSDMhK+LfWQ z=igVE(5BigAn#BHC2HwFB5kU6DKBCS@I9#Nq`U;RS4DXd*Ah(Oa%e(k5igC3n~p?0 z5&cB28ZAak(efr)f^iUpyAO<M+gh+#%@S~L+O$NFVxewQg^ z(8aPfnnN3y-6(sgCNDMejII~qoIF?XzF+&ybt?mM0Xt5C+8L&Qb8X#wvL777c%49a33#5+XPiTKwl z^J|~3cx8>sz# zT%uBeIGtK{rMQwsjHtX+UOB$9qS%66oB2>6F5n8@K$1FhNf>q#aP@4sRtT}fL7EQ2 zX5O&9td|2dH)$y$ThjH0bUvkL5wtD7;UqDnLP}?dm+kfDbWtDRXq0g9W9bdkTlQ*r zd&N5j#$JhYmlCIJCj3VYXW~1n4~i<*ufsNKexw>vidZhoL}_%lV$dHu0JYdz;~kg|I>BfX(q{<$fz`OtoW(v_veA4? z_B}vn)VKgTcHeZ=nA*O?K^&<3qJ)q&)~92LU52oU{U&ZBbwwe>Sx|z`$d4E4j!bP` z_&xO+-}W${o85ax)vKT$a){)Csm*1yn^Ojl;;{-Rl)9LkQYCW*I-K88Ii1?O`-_}f zS?4-3oYr*jIN6(rR{+zO_;|@~uUbQ6xbOLNs}`p37-;b^j@j>N5PMKN@93#_qp|vR zCUfKYB511|54DnkQ6OlunvR2SPxp!aSx=7)k}?Uo7E3N{B1;bAJ=vAGXS$Di>oMJn zgwy^29zB5QddcQG`2{q&kDcW2v6G{dGhLinbi9(2tg*sxFee~b9_8cs(7dEt_=hNf zQ)S~`95?{l3xvJyq_hD2LW@dV21lH`x?(ed#t@$u!2IC|K2HCeEVC9?t-)R6{+Cb! zmI5}y{xL#r4xlzXpe@+O&25QQ<;19vNF3@{I>QJ5_QBwA7*jS-4 z2WXrXm$4dE36x3mTwt-qs0686j4HI*rt!+JV1Kc6(}WL}XWu;FvgFC7XbEa5qcPTP zCoV#=Cc-5JZ0Cxk#bv=oX`L=kY7fk{`ejV!WD}C|s(y{huOf*b*z)__`t_UiqXPZn zCQh00(1DyKi}QZ6Jq)-$!qzm6yNIvA`nBbGUg7FCDspH63c!#%b95_5nrBEzp|F;%s(K%( zf`Z@A9Jhkh&oMdABB8I*xgHblSaFu2tI@&cHqd)2xW_m*Vy7YCa*FPNAES}K8gfU@CTSDu$wuxbJ(=d z`3pHy9EE{{R43v7P0dJpbY~sxaT!hC&XzyHqarv{v|NfJCmKh)eEi8Nsg$K7YF0Oq z}W&*1g2)qQ`3ml$0F)+1{rcc?GbZ8oa zg44mOj_0OkN+#(+&_2gW(wG;X+ll#}yoCqHhj>YEoY&7ED_Th1dFo{Z4|sB2&WH-0$LTE-Si?cUn`cO&sU!C)n3Q*% ziDCpTdEP4^FK0!)BF>5;2+p`Dw*)~)^mgAw`vUO{V`-ovdW+*pycf{}#K^drME#%A zCL2s;kK-fqm~K}A@&Wb%hA9b#fzN;upiG}2S`hStGIMDgcoE?DlJ*6zz}c+m|7Bpq z0|x~CyY7Uv`04V+!Q)UYn$_6xW@)yO^fk@e=K=B1}%^Jx&S zH*jzj$urMikIffi^ZD5P*gV^Q3%~FMvaMJhtL9V?!+jUhs=Z=A1}J{7xn!@P^pG`L zJ;ciNhA9i*W@o>{&e}S%`DY|>dS`Gt`8%~d1W61Syb5W)Xxuat4rl1VF7l|GecX5r zy8*}}7Q;OG7_G3Y0E=aTqZzlwmI_zCh<`D0bC#EE%w#2%|0S_l1@-5iy7KH(Pd^f*uh3P5_b3iNb-p{mvIkEhGkMR{Z{v;piv#U_?Ba-a z$5Y>vJpNUL|4z{j0Ot<_eN<3A8x!&?B{gj&lDZbDY*6gM0oM zO-Ae=&7B0e^D7UoR21oeebSeeINNM`gEkI2yNE9g_k=AP<)Gae?wx`v?zOgjtr2bh z;&+*d-rHwR=IfKpy~t#NiHv~mz7El^F!w1YEhZaGZZcsvrlvH-YjN(pK737gN{f`^ Zp!?YG$k;2n6{q@S^+ffZ)nWCO{{iR712q5u literal 0 HcmV?d00001 diff --git a/evaltools/eval.py b/evaltools/eval.py new file mode 100644 index 0000000..fee5061 --- /dev/null +++ b/evaltools/eval.py @@ -0,0 +1,162 @@ +import os +import sys +import cv2 +from tqdm import tqdm +import metrics +import json +import argparse +import numpy as np + +def Borders_Capture(gt,pred,dksize=15): + gray = cv2.cvtColor(gt, cv2.COLOR_BGR2GRAY) + ret, binary = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY) + + contours, hierarchy = cv2.findContours(binary, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) + img=gt.copy() + img[:]=0 + cv2.drawContours(img, contours, -1, (255, 255, 255), 3) + kernel = np.ones((dksize, dksize), np.uint8) + img_dilate = cv2.dilate(img, kernel) + + res = cv2.bitwise_and(img_dilate, gt) + b, g, r = cv2.split(res) + alpha = np.rollaxis(img_dilate, 2, 0)[0] + merge = cv2.merge((b, g, r, 
alpha)) + + resp = cv2.bitwise_and(img_dilate, pred) + b, g, r = cv2.split(resp) + alpha = np.rollaxis(img_dilate, 2, 0)[0] + mergep = cv2.merge((b, g, r, alpha)) + + merge = cv2.cvtColor(merge, cv2.COLOR_RGB2GRAY) + mergep = cv2.cvtColor(mergep, cv2.COLOR_RGB2GRAY) + return merge,mergep,np.sum(img_dilate)/255 + +def eval(parser, dataset): + args = parser.parse_args() + + FM = metrics.Fmeasure_and_FNR() + WFM = metrics.WeightedFmeasure() + SM = metrics.Smeasure() + EM = metrics.Emeasure() + MAE = metrics.MAE() + + BR_MAE = metrics.MAE() + BR_wF = metrics.WeightedFmeasure() + + model = args.model + gt_root = args.GT_root + pred_root = args.pred_root + + gt_root = os.path.join(gt_root, dataset) + gt_root = os.path.join(gt_root, 'GT') + pred_root = os.path.join(pred_root, dataset) + + gt_name_list = sorted(os.listdir(pred_root)) + + for gt_name in tqdm(gt_name_list, total=len(gt_name_list)): + gt_path = os.path.join(gt_root, gt_name) + pred_path = os.path.join(pred_root, gt_name) + gt = cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE) + gt_width, gt_height = gt.shape + pred = cv2.imread(pred_path, cv2.IMREAD_GRAYSCALE) + pred_width, pred_height = pred.shape + if gt.shape != pred.shape: + cv2.imwrite( os.path.join(pred_root, gt_name), cv2.resize(pred, gt.shape[::-1])) + pred = cv2.imread(pred_path, cv2.IMREAD_GRAYSCALE) + + FM.step(pred=pred, gt=gt) + WFM.step(pred=pred, gt=gt) + SM.step(pred=pred, gt=gt) + EM.step(pred=pred, gt=gt) + MAE.step(pred=pred, gt=gt) + + if args.BR == 'on': + BR_gt, BR_pred, area=Borders_Capture(cv2.imread(gt_path), cv2.imread(pred_path), int(args.br_rate)) + BR_MAE.step(pred=BR_pred, gt=BR_gt,area=area) + BR_wF.step(pred=BR_pred, gt=BR_gt) + + fm = FM.get_results()[0]['fm'] + wfm = WFM.get_results()['wfm'] + sm = SM.get_results()['sm'] + em = EM.get_results()['em'] + mae = MAE.get_results()['mae'] + fnr = FM.get_results()[1] + if args.BR == 'on': + BRmae= BR_MAE.get_results()['mae'] + BRmae_r = str(BRmae.round(3)) + BRwF = BR_wF.get_results()['wfm'] + BRwF_r = str(BRwF.round(3)) + model_r = str(args.model) + Smeasure_r = str(sm.round(3)) + Wmeasure_r = str(wfm.round(3)) + MAE_r = str(mae.round(3)) + adpEm_r = str(em['adp'].round(3)) + meanEm_r = str('-' if em['curve'] is None else em['curve'].mean().round(3)) + maxEm_r = str('-' if em['curve'] is None else em['curve'].max().round(3)) + adpFm_r = str(fm['adp'].round(3)) + meanFm_r = str(fm['curve'].mean().round(3)) + maxFm_r = str(fm['curve'].max().round(3)) + fnr_r = str(fnr.round(3)) + + if args.BR == 'on': + eval_record = str( + 'Model:'+ model_r + ','+ + 'Dataset:'+ dataset + '||'+ + 'Smeasure:'+ Smeasure_r + '; '+ + 'meanEm:'+ meanEm_r + '; '+ + 'wFmeasure:'+ Wmeasure_r + '; '+ + 'MAE:'+ MAE_r + '; '+ + 'fnr:'+ fnr_r + ';' + + 'adpEm:'+ adpEm_r + '; '+ + 'meanEm:'+ meanEm_r + '; '+ + 'maxEm:'+ maxEm_r + '; '+ + 'adpFm:'+ adpFm_r + '; '+ + 'meanFm:'+ meanFm_r + '; '+ + 'maxFm:'+ maxFm_r+ ';' + + 'BR'+str(args.br_rate)+'_mae:' + BRmae_r + ';' + + 'BR'+str(args.br_rate)+'_wF:' + BRwF_r + ) + else: + eval_record = str( + 'Model:'+ model_r + ','+ + 'Dataset:'+ dataset + '||'+ + 'Smeasure:'+ Smeasure_r + '; '+ + 'meanEm:'+ meanEm_r + '; '+ + 'wFmeasure:'+ Wmeasure_r + '; '+ + 'MAE:'+ MAE_r + '; '+ + 'fnr:'+ fnr_r + ';' + + 'adpEm:'+ adpEm_r + '; '+ + 'meanEm:'+ meanEm_r + '; '+ + 'maxEm:'+ maxEm_r + '; '+ + 'adpFm:'+ adpFm_r + '; '+ + 'meanFm:'+ meanFm_r + '; '+ + 'maxFm:'+ maxFm_r + ) + + + print(eval_record) + print('#'*50) + if args.record_path is not None: + txt = args.record_path + else: + txt = 
'output/eval_record.txt' + f = open(txt, 'a') + f.write(eval_record) + f.write("\n") + f.close() + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("--model", default='CamoFormer') + parser.add_argument("--pred_root", default='Prediction/CamoFormer') + parser.add_argument("--GT_root", default='Dataset/TestData') + parser.add_argument("--record_path", default=None) + parser.add_argument("--BR", default='off') + parser.add_argument("--br_rate", default=15) + args = parser.parse_args() + datasets = ['NC4K', 'COD10K', 'CAMO', 'CHAMELEON'] + existed_pred = os.listdir(args.pred_root) + for dataset in datasets: + if dataset in existed_pred: + eval(parser, dataset) diff --git a/evaltools/metrics.py b/evaltools/metrics.py new file mode 100644 index 0000000..7b14faf --- /dev/null +++ b/evaltools/metrics.py @@ -0,0 +1,442 @@ +""" +Code: Metrics +Desc: This code heavily borrowed from https://github.com/lartpang with slight modifications. +""" +import numpy as np +from scipy.ndimage import convolve, distance_transform_edt as bwdist + +_EPS = 1e-16 +_TYPE = np.float64 + + +def _prepare_data(pred: np.ndarray, gt: np.ndarray) -> tuple: + gt = gt > 128 + pred = pred / 255 + if pred.max() != pred.min(): + pred = (pred - pred.min()) / (pred.max() - pred.min()) + return pred, gt + + +def _get_adaptive_threshold(matrix: np.ndarray, max_value: float = 1) -> float: + return min(2 * matrix.mean(), max_value) + + +class Fmeasure_and_FNR(object): + def __init__(self, beta: float = 0.3): + self.beta = beta + self.precisions = [] + self.recalls = [] + self.fnrs = [] + self.adaptive_fms = [] + self.changeable_fms = [] + + def step(self, pred: np.ndarray, gt: np.ndarray): + pred, gt = _prepare_data(pred, gt) + + adaptive_fm = self.cal_adaptive_fm(pred=pred, gt=gt) + self.adaptive_fms.append(adaptive_fm) + + precisions, recalls, changeable_fms = self.cal_pr(pred=pred, gt=gt) + fnrs = 1 - recalls + self.precisions.append(precisions) + self.recalls.append(recalls) + self.fnrs.append(fnrs) + self.changeable_fms.append(changeable_fms) + + def cal_adaptive_fm(self, pred: np.ndarray, gt: np.ndarray) -> float: + adaptive_threshold = _get_adaptive_threshold(pred, max_value=1) + binary_predcition = pred >= adaptive_threshold + area_intersection = binary_predcition[gt].sum() + if area_intersection == 0: + adaptive_fm = 0 + else: + pre = area_intersection / np.count_nonzero(binary_predcition) + rec = area_intersection / np.count_nonzero(gt) + adaptive_fm = (1 + self.beta) * pre * rec / (self.beta * pre + rec) + return adaptive_fm + + def cal_pr(self, pred: np.ndarray, gt: np.ndarray) -> tuple: + pred = (pred * 255).astype(np.uint8) + bins = np.linspace(0, 256, 257) + fg_hist, _ = np.histogram(pred[gt], bins=bins) + bg_hist, _ = np.histogram(pred[~gt], bins=bins) + fg_w_thrs = np.cumsum(np.flip(fg_hist), axis=0) + bg_w_thrs = np.cumsum(np.flip(bg_hist), axis=0) + TPs = fg_w_thrs + Ps = fg_w_thrs + bg_w_thrs + Ps[Ps == 0] = 1 + T = max(np.count_nonzero(gt), 1) + precisions = TPs / Ps + recalls = TPs / T + numerator = (1 + self.beta) * precisions * recalls + denominator = np.where(numerator == 0, 1, self.beta * precisions + recalls) + changeable_fms = numerator / denominator + return precisions, recalls, changeable_fms + + def get_results(self) -> dict: + adaptive_fm = np.mean(np.array(self.adaptive_fms, _TYPE)) + changeable_fm = np.mean(np.array(self.changeable_fms, dtype=_TYPE), axis=0) + precision = np.mean(np.array(self.precisions, dtype=_TYPE), axis=0) # N, 256 + recall = 
np.mean(np.array(self.recalls, dtype=_TYPE), axis=0) # N, 256 + fnr = np.mean(self.fnrs, dtype=_TYPE) + return dict(fm=dict(adp=adaptive_fm, curve=changeable_fm), + pr=dict(p=precision, r=recall)), fnr + + +class MAE(object): + def __init__(self): + self.maes = [] + + def step(self, pred: np.ndarray, gt: np.ndarray, area = None): + pred, gt = _prepare_data(pred, gt) + + mae = self.cal_mae(pred, gt, area) + self.maes.append(mae) + + def cal_mae(self, pred: np.ndarray, gt: np.ndarray, area) -> float: + if area is not None: + mae = np.sum(np.abs(pred - gt))/np.sum(area) + else: + mae = np.mean(np.abs(pred - gt)) + return mae + + def get_results(self) -> dict: + mae = np.mean(np.array(self.maes, _TYPE)) + return dict(mae=mae) + + + + + +class FNR(object): + def __init__(self, beta: float = 0.3): + self.beta = beta + self.precisions = [] + self.recalls = [] + self.fnrs = [] + + + def step(self, pred: np.ndarray, gt: np.ndarray): + pred, gt = _prepare_data(pred, gt) + + precisions, recalls, changeable_fms = self.cal_pr(pred=pred, gt=gt) + fnr = 1 - recalls + + self.fnrs.append(fnr) + + def cal_pr(self, pred: np.ndarray, gt: np.ndarray) -> tuple: + pred = (pred * 255).astype(np.uint8) + bins = np.linspace(0, 256, 257) + fg_hist, _ = np.histogram(pred[gt], bins=bins) + bg_hist, _ = np.histogram(pred[~gt], bins=bins) + fg_w_thrs = np.cumsum(np.flip(fg_hist), axis=0) + bg_w_thrs = np.cumsum(np.flip(bg_hist), axis=0) + TPs = fg_w_thrs + Ps = fg_w_thrs + bg_w_thrs + Ps[Ps == 0] = 1 + T = max(np.count_nonzero(gt), 1) + precisions = TPs / Ps + recalls = TPs / T + numerator = (1 + self.beta) * precisions * recalls + denominator = np.where(numerator == 0, 1, self.beta * precisions + recalls) + changeable_fms = numerator / denominator + return precisions, recalls, changeable_fms + + def get_results(self) -> dict: + #fnr = np.mean(np.array(self.fnrs, dtype=_TYPE), axis=0) # N, 256 + fnr = np.mean(self.fnrs, dtype=_TYPE) + return dict(fnr=fnr) + + + + +class Smeasure(object): + def __init__(self, alpha: float = 0.5): + self.sms = [] + self.alpha = alpha + + def step(self, pred: np.ndarray, gt: np.ndarray): + pred, gt = _prepare_data(pred=pred, gt=gt) + + sm = self.cal_sm(pred, gt) + self.sms.append(sm) + #return sm + + def cal_sm(self, pred: np.ndarray, gt: np.ndarray) -> float: + y = np.mean(gt) + if y == 0: + sm = 1 - np.mean(pred) + elif y == 1: + sm = np.mean(pred) + else: + sm = self.alpha * self.object(pred, gt) + (1 - self.alpha) * self.region(pred, gt) + sm = max(0, sm) + return sm + + def object(self, pred: np.ndarray, gt: np.ndarray) -> float: + fg = pred * gt + bg = (1 - pred) * (1 - gt) + u = np.mean(gt) + object_score = u * self.s_object(fg, gt) + (1 - u) * self.s_object(bg, 1 - gt) + return object_score + + def s_object(self, pred: np.ndarray, gt: np.ndarray) -> float: + x = np.mean(pred[gt == 1]) + sigma_x = np.std(pred[gt == 1]) + score = 2 * x / (np.power(x, 2) + 1 + sigma_x + _EPS) + return score + + def region(self, pred: np.ndarray, gt: np.ndarray) -> float: + x, y = self.centroid(gt) + part_info = self.divide_with_xy(pred, gt, x, y) + w1, w2, w3, w4 = part_info['weight'] + pred1, pred2, pred3, pred4 = part_info['pred'] + gt1, gt2, gt3, gt4 = part_info['gt'] + score1 = self.ssim(pred1, gt1) + score2 = self.ssim(pred2, gt2) + score3 = self.ssim(pred3, gt3) + score4 = self.ssim(pred4, gt4) + + return w1 * score1 + w2 * score2 + w3 * score3 + w4 * score4 + + def centroid(self, matrix: np.ndarray) -> tuple: + h, w = matrix.shape + if matrix.sum() == 0: + x = np.round(w / 2) + y = np.round(h / 
2) + else: + area_object = np.sum(matrix) + row_ids = np.arange(h) + col_ids = np.arange(w) + x = np.round(np.sum(np.sum(matrix, axis=0) * col_ids) / area_object) + y = np.round(np.sum(np.sum(matrix, axis=1) * row_ids) / area_object) + return int(x) + 1, int(y) + 1 + + def divide_with_xy(self, pred: np.ndarray, gt: np.ndarray, x, y) -> dict: + h, w = gt.shape + area = h * w + + gt_LT = gt[0:y, 0:x] + gt_RT = gt[0:y, x:w] + gt_LB = gt[y:h, 0:x] + gt_RB = gt[y:h, x:w] + + pred_LT = pred[0:y, 0:x] + pred_RT = pred[0:y, x:w] + pred_LB = pred[y:h, 0:x] + pred_RB = pred[y:h, x:w] + + w1 = x * y / area + w2 = y * (w - x) / area + w3 = (h - y) * x / area + w4 = 1 - w1 - w2 - w3 + + return dict(gt=(gt_LT, gt_RT, gt_LB, gt_RB), + pred=(pred_LT, pred_RT, pred_LB, pred_RB), + weight=(w1, w2, w3, w4)) + + def ssim(self, pred: np.ndarray, gt: np.ndarray) -> float: + h, w = pred.shape + N = h * w + + x = np.mean(pred) + y = np.mean(gt) + + sigma_x = np.sum((pred - x) ** 2) / (N - 1) + sigma_y = np.sum((gt - y) ** 2) / (N - 1) + sigma_xy = np.sum((pred - x) * (gt - y)) / (N - 1) + + alpha = 4 * x * y * sigma_xy + beta = (x ** 2 + y ** 2) * (sigma_x + sigma_y) + + if alpha != 0: + score = alpha / (beta + _EPS) + elif alpha == 0 and beta == 0: + score = 1 + else: + score = 0 + return score + + def get_results(self) -> dict: + sm = np.mean(np.array(self.sms, dtype=_TYPE)) + return dict(sm=sm) + + +class Emeasure(object): + def __init__(self): + self.adaptive_ems = [] + self.changeable_ems = [] + + def step(self, pred: np.ndarray, gt: np.ndarray): + pred, gt = _prepare_data(pred=pred, gt=gt) + self.gt_fg_numel = np.count_nonzero(gt) + self.gt_size = gt.shape[0] * gt.shape[1] + + changeable_ems = self.cal_changeable_em(pred, gt) + self.changeable_ems.append(changeable_ems) + adaptive_em = self.cal_adaptive_em(pred, gt) + self.adaptive_ems.append(adaptive_em) + + def cal_adaptive_em(self, pred: np.ndarray, gt: np.ndarray) -> float: + adaptive_threshold = _get_adaptive_threshold(pred, max_value=1) + adaptive_em = self.cal_em_with_threshold(pred, gt, threshold=adaptive_threshold) + return adaptive_em + + def cal_changeable_em(self, pred: np.ndarray, gt: np.ndarray) -> np.ndarray: + changeable_ems = self.cal_em_with_cumsumhistogram(pred, gt) + return changeable_ems + + def cal_em_with_threshold(self, pred: np.ndarray, gt: np.ndarray, threshold: float) -> float: + binarized_pred = pred >= threshold + fg_fg_numel = np.count_nonzero(binarized_pred & gt) + fg_bg_numel = np.count_nonzero(binarized_pred & ~gt) + + fg___numel = fg_fg_numel + fg_bg_numel + bg___numel = self.gt_size - fg___numel + + if self.gt_fg_numel == 0: + enhanced_matrix_sum = bg___numel + elif self.gt_fg_numel == self.gt_size: + enhanced_matrix_sum = fg___numel + else: + parts_numel, combinations = self.generate_parts_numel_combinations( + fg_fg_numel=fg_fg_numel, fg_bg_numel=fg_bg_numel, + pred_fg_numel=fg___numel, pred_bg_numel=bg___numel, + ) + + results_parts = [] + for i, (part_numel, combination) in enumerate(zip(parts_numel, combinations)): + align_matrix_value = 2 * (combination[0] * combination[1]) / \ + (combination[0] ** 2 + combination[1] ** 2 + _EPS) + enhanced_matrix_value = (align_matrix_value + 1) ** 2 / 4 + results_parts.append(enhanced_matrix_value * part_numel) + enhanced_matrix_sum = sum(results_parts) + + em = enhanced_matrix_sum / (self.gt_size - 1 + _EPS) + return em + + def cal_em_with_cumsumhistogram(self, pred: np.ndarray, gt: np.ndarray) -> np.ndarray: + pred = (pred * 255).astype(np.uint8) + bins = np.linspace(0, 256, 257) 
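+ # Cumulative threshold counts: histogram the 8-bit prediction separately over GT
+ # foreground and background pixels, then flip+cumsum so that index i holds the number
+ # of pixels with prediction value >= 255-i (one foreground/background count per candidate threshold).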
+ fg_fg_hist, _ = np.histogram(pred[gt], bins=bins) + fg_bg_hist, _ = np.histogram(pred[~gt], bins=bins) + fg_fg_numel_w_thrs = np.cumsum(np.flip(fg_fg_hist), axis=0) + fg_bg_numel_w_thrs = np.cumsum(np.flip(fg_bg_hist), axis=0) + + fg___numel_w_thrs = fg_fg_numel_w_thrs + fg_bg_numel_w_thrs + bg___numel_w_thrs = self.gt_size - fg___numel_w_thrs + + if self.gt_fg_numel == 0: + enhanced_matrix_sum = bg___numel_w_thrs + elif self.gt_fg_numel == self.gt_size: + enhanced_matrix_sum = fg___numel_w_thrs + else: + parts_numel_w_thrs, combinations = self.generate_parts_numel_combinations( + fg_fg_numel=fg_fg_numel_w_thrs, fg_bg_numel=fg_bg_numel_w_thrs, + pred_fg_numel=fg___numel_w_thrs, pred_bg_numel=bg___numel_w_thrs, + ) + + results_parts = np.empty(shape=(4, 256), dtype=np.float64) + for i, (part_numel, combination) in enumerate(zip(parts_numel_w_thrs, combinations)): + align_matrix_value = 2 * (combination[0] * combination[1]) / \ + (combination[0] ** 2 + combination[1] ** 2 + _EPS) + enhanced_matrix_value = (align_matrix_value + 1) ** 2 / 4 + results_parts[i] = enhanced_matrix_value * part_numel + enhanced_matrix_sum = results_parts.sum(axis=0) + + em = enhanced_matrix_sum / (self.gt_size - 1 + _EPS) + return em + + def generate_parts_numel_combinations(self, fg_fg_numel, fg_bg_numel, pred_fg_numel, pred_bg_numel): + bg_fg_numel = self.gt_fg_numel - fg_fg_numel + bg_bg_numel = pred_bg_numel - bg_fg_numel + + parts_numel = [fg_fg_numel, fg_bg_numel, bg_fg_numel, bg_bg_numel] + + mean_pred_value = pred_fg_numel / self.gt_size + mean_gt_value = self.gt_fg_numel / self.gt_size + + demeaned_pred_fg_value = 1 - mean_pred_value + demeaned_pred_bg_value = 0 - mean_pred_value + demeaned_gt_fg_value = 1 - mean_gt_value + demeaned_gt_bg_value = 0 - mean_gt_value + + combinations = [ + (demeaned_pred_fg_value, demeaned_gt_fg_value), + (demeaned_pred_fg_value, demeaned_gt_bg_value), + (demeaned_pred_bg_value, demeaned_gt_fg_value), + (demeaned_pred_bg_value, demeaned_gt_bg_value) + ] + return parts_numel, combinations + + def get_results(self) -> dict: + adaptive_em = np.mean(np.array(self.adaptive_ems, dtype=_TYPE)) + changeable_em = np.mean(np.array(self.changeable_ems, dtype=_TYPE), axis=0) + return dict(em=dict(adp=adaptive_em, curve=changeable_em)) + + +class WeightedFmeasure(object): + def __init__(self, beta: float = 1): + self.beta = beta + self.weighted_fms = [] + + def step(self, pred: np.ndarray, gt: np.ndarray): + pred, gt = _prepare_data(pred=pred, gt=gt) + + if np.all(~gt): + wfm = 0 + else: + wfm = self.cal_wfm(pred, gt) + self.weighted_fms.append(wfm) + + def cal_wfm(self, pred: np.ndarray, gt: np.ndarray) -> float: + # [Dst,IDXT] = bwdist(dGT); + Dst, Idxt = bwdist(gt == 0, return_indices=True) + + # %Pixel dependency + # E = abs(FG-dGT); + E = np.abs(pred - gt) + Et = np.copy(E) + Et[gt == 0] = Et[Idxt[0][gt == 0], Idxt[1][gt == 0]] + + # K = fspecial('gaussian',7,5); + # EA = imfilter(Et,K); + K = self.matlab_style_gauss2D((7, 7), sigma=5) + EA = convolve(Et, weights=K, mode="constant", cval=0) + # MIN_E_EA = E; + # MIN_E_EA(GT & EA np.ndarray: + """ + 2D gaussian mask - should give the same result as MATLAB's + fspecial('gaussian',[shape],[sigma]) + """ + m, n = [(ss - 1) / 2 for ss in shape] + y, x = np.ogrid[-m: m + 1, -n: n + 1] + h = np.exp(-(x * x + y * y) / (2 * sigma * sigma)) + h[h < np.finfo(h.dtype).eps * h.max()] = 0 + sumh = h.sum() + if sumh != 0: + h /= sumh + return h + + def get_results(self) -> dict: + weighted_fm = np.mean(np.array(self.weighted_fms, dtype=_TYPE)) + 
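+ # Dataset-level score: average of the per-image weighted F-measure values collected in step().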
return dict(wfm=weighted_fm) diff --git a/evaltools/valid_eval.sh b/evaltools/valid_eval.sh new file mode 100644 index 0000000..cdbf346 --- /dev/null +++ b/evaltools/valid_eval.sh @@ -0,0 +1,8 @@ +pred_path=${1} +gt_path=${2} +name=${3} + +python evaltools/eval.py \ + --pred_root $pred_path \ + --GT_root $gt_path \ + --model $name \ No newline at end of file diff --git a/main.py b/main.py new file mode 100644 index 0000000..8cb60bf --- /dev/null +++ b/main.py @@ -0,0 +1,184 @@ +import os +import sys +import datetime +import math +import dataset +import argparse +import cv2 +import torch +import numpy as np +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter +from torch.autograd import Variable +from torch.optim import lr_scheduler +from apex import amp +from model.CamoFormer import CamoFormer +import matplotlib.pyplot as plt +plt.ion() + +sys.path.insert(0, '../') +sys.dont_write_bytecode = True +os.environ["CUDA_VISIBLE_DEVICES"] = '1' + + + +def loss_cal(pred, mask): + pred = torch.sigmoid(pred) + inter = (pred*mask).sum(dim=(2,3)) + union = (pred+mask).sum(dim=(2,3)) + iou = 1-(inter+1)/(union-inter+1) + return iou.mean() + + +def train(Dataset, parser): + + args = parser.parse_args() + _MODEL_ = args.model + _DATASET_ = args.dataset + _TESTDATASET_ = args.test_dataset + _LR_ = args.lr + _DECAY_ = args.decay + _MOMEN_ = args.momen + _BATCHSIZE_ = args.batchsize + _EPOCH_ = args.epoch + _SAVEPATH_ = args.savepath + _VALID_ = args.valid + _WEIGHT_ = args.weight + _PRETRAINPATH_ = args.pretrain_path + + print(args) + + cfg = Dataset.Config(datapath=_DATASET_, savepath=_SAVEPATH_, mode='train', batch=_BATCHSIZE_, lr=_LR_, momen=_MOMEN_, decay=_DECAY_, epoch=_EPOCH_) + + data = Dataset.Data(cfg, _MODEL_) + loader = DataLoader(data, collate_fn=data.collate, batch_size=cfg.batch, shuffle=True, pin_memory=True, num_workers=6) + ## network + net = CamoFormer(cfg, _PRETRAINPATH_) + net = net.cuda() + + net.train(True) + net.cuda() + ## parameter + base, head = [], [] + + for name, param in net.named_parameters(): + if 'encoder.conv1' in name or 'encoder.bn1' in name: + pass + elif 'encoder' in name: + base.append(param) + elif 'network' in name: + base.append(param) + else: + head.append(param) + + optimizer = torch.optim.SGD([{'params':base}, {'params':head}], lr=cfg.lr, momentum=cfg.momen, weight_decay=cfg.decay, nesterov=True) + + net, optimizer = amp.initialize(net, optimizer, opt_level='O1') + sw = SummaryWriter(cfg.savepath) + global_step = 0 + + for epoch in range(cfg.epoch): + optimizer.param_groups[0]['lr'] = 0.5*(1 + math.cos(math.pi * (epoch ) / (cfg.epoch)))*cfg.lr*0.1 + optimizer.param_groups[1]['lr'] = 0.5*(1 + math.cos(math.pi * (epoch ) / (cfg.epoch)))*cfg.lr + + for step, (image, mask) in enumerate(loader): + image, mask = image.cuda(), mask.cuda() + P5, P4, P3, P2, P1 = net(image) + + loss5 = loss_cal(P5, mask)+F.binary_cross_entropy_with_logits(P5,mask) + loss4 = loss_cal(P4, mask)+F.binary_cross_entropy_with_logits(P4,mask) + loss3 = loss_cal(P3, mask)+F.binary_cross_entropy_with_logits(P3,mask) + loss2 = loss_cal(P2, mask)+F.binary_cross_entropy_with_logits(P2,mask) + loss1 = loss_cal(P1, mask)+F.binary_cross_entropy_with_logits(P1,mask) + + loss = _WEIGHT_[0]*loss5 + _WEIGHT_[1]*loss4 + _WEIGHT_[2]*loss3 + _WEIGHT_[3]*loss2 + _WEIGHT_[4]*loss1 + + optimizer.zero_grad() + with amp.scale_loss(loss, optimizer) as scale_loss: + scale_loss.backward() + optimizer.step() + 
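+ # Gradient clipping; note that it is invoked after optimizer.step() in this loop, so the
+ # clipped gradients are discarded by the next zero_grad() rather than affecting this update.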
torch.nn.utils.clip_grad_norm_(net.parameters(), 1.0) + + global_step += 1 + if step%10 == 0: + print('%s | step:%d/%d/%d | lr=%.6f | loss=%.6f | loss1=%.6f | loss2=%.6f | loss3=%.6f | loss4=%.6f | loss5=%.6f |' + %(datetime.datetime.now(), global_step, epoch+1, cfg.epoch, optimizer.param_groups[0]['lr'], loss.item(),loss1.item(),loss2.item(),loss3.item(),loss4.item(),loss5.item())) + + + if epoch %5 == 0: + torch.save(net.state_dict(), cfg.savepath+'/'+_MODEL_+str(epoch+1)) + # 'CHAMELEON','COD10K','NC4K','CAMO' + for path in ['COD10K']: + path=_TESTDATASET_+'/'+path + t = Valid(dataset, path, epoch, 'CamoFormer','output/checkpoint/CamoFormer/CamoFormer/' ) + t.save() + os.system('bash evaltools/valid_eval.sh '+'output/Prediction/CamoFormer-epoch'+str(epoch+1)+' '+_TESTDATASET_) + +def test(dataset,parser): + args = parser.parse_args() + _TESTDATASET_ = args.test_dataset + _CKPT_ = args.ckpt + + + for path in ['CHAMELEON','COD10K','NC4K','CAMO']: + path=_TESTDATASET_+'/'+path + t = Valid(dataset, path, 0, 'CamoFormer', _CKPT_, mode='test') + t.save() + + +class Valid(object): + def __init__(self, Dataset, Path, epoch, model_name, checkpoint_path, mode='Valid'): + ## dataset + if mode == 'test': + self.cfg = Dataset.Config(datapath=Path, snapshot=checkpoint_path, mode='test') + else: + self.cfg = Dataset.Config(datapath=Path, snapshot=checkpoint_path+model_name+str(epoch+1), mode='test') + self.mode = mode + self.data = Dataset.Data(self.cfg, model_name) + self.loader = DataLoader(self.data, batch_size=1, shuffle=False, num_workers=8) + ## network + self.net = CamoFormer(self.cfg) + self.net.train(False) + self.net.cuda() + self.epoch = epoch + + def save(self): + with torch.no_grad(): + for image, (H, W), name in self.loader: + image, shape = image.cuda().float(), (H, W) + P5, P4, P3, P2, P1 = self.net(image, shape) + pred = torch.sigmoid(P1[0,0]).cpu().numpy()*255 + if self.mode == 'test': + head = 'output/Prediction/CamoFormer-test'+'/'+ self.cfg.datapath.split('/')[-1] + else: + head = 'output/Prediction/CamoFormer-epoch'+str(self.epoch+1)+'/'+ self.cfg.datapath.split('/')[-1] + if not os.path.exists(head): + os.makedirs(head) + cv2.imwrite(head+'/'+name[0].replace('.jpg','')+'.png', np.round(pred)) + + + +if __name__=='__main__': + parser = argparse.ArgumentParser() + parser.add_argument("--model", default='CamoFormer') + parser.add_argument("--dataset", default='dataset/TrainDataset') + parser.add_argument("--test_dataset", default='dataset/TestDataset') + parser.add_argument("--lr", type=float, default=0.05) + parser.add_argument("--momen", type=float, default=0.9) + parser.add_argument("--decay", type=float, default=1e-4) + parser.add_argument("--batchsize", type=int, default=14) + parser.add_argument("--epoch", type=int, default=60) + parser.add_argument("--savepath", default='output/checkpoint/CamoFormer/CamoFormer') + parser.add_argument("--weight", default=[0.5,0.5,0.8,1.0,2.0]) + parser.add_argument("--valid", default=True) + parser.add_argument("--mode", default='train') + parser.add_argument("--ckpt", default='CamoFormer.pth') + parser.add_argument("--pretrain_path",default=None) + args = parser.parse_args() + if args.mode == 'train': + train(dataset, parser) + else: + test(dataset, parser) + \ No newline at end of file diff --git a/model/CamoFormer.py b/model/CamoFormer.py new file mode 100644 index 0000000..22c256a --- /dev/null +++ b/model/CamoFormer.py @@ -0,0 +1,104 @@ +import torch +from torch import nn +from torch.utils import model_zoo + +from .encoder.swin_encoder 
import SwinTransformer +from .encoder.pvtv2_encoder import pvt_v2_b4 +from .decoder.decoder_p import Decoder + +from timm.models import create_model +import collections +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import cv2 + + +def weight_init_backbone(module): + for n, m in module.named_children(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu') + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d)): + nn.init.ones_(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu') + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Sequential): + weight_init(m) + elif isinstance(m, (nn.ReLU, nn.Sigmoid, nn.PReLU, nn.AdaptiveAvgPool2d, nn.AdaptiveAvgPool1d, nn.Sigmoid, nn.Identity)): + pass + else: + m.initialize() + +def weight_init(module): + for n, m in module.named_children(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu') + #nn.init.xavier_normal_(m.weight, gain=1) + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d)): + nn.init.ones_(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu') + #nn.init.xavier_normal_(m.weight, gain=1) + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Sequential): + weight_init(m) + elif isinstance(m, (nn.ReLU, nn.Sigmoid, nn.PReLU, nn.AdaptiveAvgPool2d, nn.AdaptiveAvgPool1d, nn.Sigmoid, nn.Identity)): + pass + else: + m.initialize() + +class CamoFormer(torch.nn.Module): + def __init__(self, cfg, load_path=None): + super(CamoFormer, self).__init__() + self.cfg = cfg + self.encoder = pvt_v2_b4() + if load_path is not None: + pretrained_dict = torch.load(load_path) + pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in self.encoder.state_dict()} + self.encoder.load_state_dict(pretrained_dict) + print('Pretrained encoder loaded.') + + self.decoder = Decoder(128) + self.initialize() + + def _make_pred_layer(self, block, dilation_series, padding_series, NoLabels, input_channel): + return block(dilation_series, padding_series, NoLabels, input_channel) + + def forward(self, x, shape=None, name=None): + + features = self.encoder(x) + x1 = features[0] + x2 = features[1] + x3 = features[2] + x4 = features[3] + + if shape is None: + shape = x.size()[2:] + + P5, P4, P3, P2, P1= self.decoder(x1, x2, x3, x4, shape) + return P5, P4, P3, P2, P1 + + def initialize(self): + if self.cfg is not None: + if self.cfg.snapshot: + self.load_state_dict(torch.load(self.cfg.snapshot)) + else: + weight_init(self) + + + + + + + diff --git a/model/__pycache__/CamoFormer.cpython-38.pyc b/model/__pycache__/CamoFormer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6d7ad6f52833e5e4eea83aa36f1c6d559c04e04c GIT binary patch literal 3249 zcmeHJ&5j(m5hmF`)6?^_|FVtPL0q8I9g1XSa8| z*|Y4Pm3Ghr1l|p-OHO`)+2daF1csl$$DDZ0XFM z`9-Vc6ZpRKzKwrq5%M2o7C#%951|zQ0>TNWDG9JSv9hVcIlp22C?>>vBGtwlm5ICD~H2Jd;`pR*lqJu*U1$ zKPACBZ}2A68=nx~;_Xwy+k))V{c9&y@Diunq_cDZHj`e5Y9=MndX{sMhNC>cpfCVE zOWRKp`MHW@(a%*T6bcQF%J66_+_?=*?SY6v2m70`5>Y8aGe?a2^Us0?I`KY~!iSNH zOvo8|K)(9%F&on{D`~|_wr`y<^%>OmzEiP%ms8HT^-pjca5YX2>grE8I`NP{1J@pU 
zGtD~m`+Wxbr$(ZI`Y3U!`SlAlo1pVOj=gZr?_QwU0?mJy?TVc&TCxPv%TQLJtU_6v z@$3B>$p3f5l5!=a_LCx!MH$IhC{eIk}Gkwwe!gEDqodtNd7i=o&v)JA+WVB_l;tsO2GnD0UGdm8I805v~){U)O;r%GfKQ=MC34Rn&A;Nwl zqjb~6`XmuZu(Qy=P8+BWr^ZvyukWlX&Za=%9T zFN<1 z10<*Qz(V{5M6>}JU0yq-fx|sshaUGEVw0c-nAp_*Rjl-4#IE^=3m}iBkU&aO(r4s^ zO&Df$1?|u(?GpkuAecd{FH3h)tJoPGTgDJD<`3zB8#;7NS+@a|_X= zY;gMoXKOwqlNQd6SkZBMKMObqtXM;H(60#q>sQyOkc!6alLsLsdTSJFJUH6w&MP-a-mPXjRQ`Zt9Vb_sH>f%6j`B{kO{9gq@96EWT|SXP~+39O8p+Vj;*$E zs`r2db)ze@Kn@Rr zb+q<>91^4X^AFp2eD|Rgm>YOVl~vi5a{>o-PJ71p9oJZ@ogHxIL9oP=G%Ax^hJ{dx zD1!DN;v9i~#(TZ|VYDOCB4{OYFf8H5h@=$h^dJ#N%?|W(m_-L7g#Gd`jgE!-GunYS z38ACkEgJuvTs+S+6h*tyyJiz9H1KM#D1^IR*_{4{j2);M)GpK()HSGWsNo`n1o70p ziei*QMu=U&RiZw?p7zv-^eM#FUNjKe#@keTeR#19l_)fOdK1c4z@FZMa$B=U??A!& z7L+Y0H=lQ44PHb~BE@^KK(Pvhz_DWxz|TCo%&ybX>Vh9<;y+1YwEi*%ZUa+6(i+@& z16%u4-9UGg*2JFjXW&{&q*po>pvM#!)Q6x5VS=%&y+THVVlOY%`@kxMLiI-^2$|CW zHZgD+2qqmK!=TA$0wL_tYizW!VENU>+@ZY=>-9Pw<|Vo?47DHPlMn7{Xg9;~a2Tc2 z5rs*j?jiXr63jc@nq^Zg(iAf4$`;I~X>MWO#(W4Kukd!Ak6{i5D>Go|&T7aB3~5*($+E6_J`~D> z9^HuZG!-$XU!lF}`cmq(d8^wWN;3*CB!8hzKh};MX7KP}asWVfv}K+T=1jEv7vnv% l!3BR@9+39#O#OQoqou%e^Z;a<42q%1tzy(!-iG(m{{Yu{By|7) literal 0 HcmV?d00001 diff --git a/model/decoder/__pycache__/decoder_p.cpython-38.pyc b/model/decoder/__pycache__/decoder_p.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3c41112e10c69fb0a904dc229685de2d46369c2b GIT binary patch literal 10868 zcmcIqTaX;rS?=ph&t+$4XGhXXvSCAElZg|_yDM{Xg0U=lV+VT?k!**-p_!fT)y`_B zXLY()m!-D~vI;gKN*Thf3LbWgP#_3T5bh8#fhS&gqfWt7sG{`93q|Dt_`Y*`ZqjT@ zq^zDgf1mT8`#Jx;=Z%?}qJih*?w8wNf5tGr$-?;KpztD6{HAFbff1NpqiL$Q)w7$9 zsc~-8Rk&O)-^^z?ua~zB)94kN#SBwwmQ<^9v#j1T%^CH6l-^di(wl8ovytYSbLeSz zYrXmAJaA6n-Z7f>AQ$BEJ{EXE0q^6VG=gGKx?=>TFu!$T-)=4dQx0Z;naMCGfvE(u zz|3ZtQ@~V%Ibh~8%xPe1!8|bY8RiLK>cKH!j%Ao9fjJ(W0Omx7`7U4GHC42p zL#iRoBh`_PAst5|{=`OM-+Ip!+<~CZ9--)Cn!tPGPl1{H5 zgt8pLZ8K!lMjcH&MHC#|R*N-K(F?m|Dhu^dRJ6Lk_f zd#%;!b)u_&)EB*0*O$4QVdv^*BAxY4E0($4Q1oM8mR@Ql?aei`LC?yo*?3vG@LKD3 zC|HngKMK)2_nKOeEM5+;?}Sm(X?0~u4~z|9TIbW@Yj4Q><<8Y!zY`$S-$;6`ThhIx zD(21yt?i_9BRqfO>ZN|ai%Ey5*ITz9L}jg(0_nXPaJu9+wvBl(Z)Z1b*fNKs+UbVU zjHKDSXNadjao<`yyV>uBXLr_jqGac6@3t?(?S6cA`PtaTcc?w&Kv= z=tQmVS)O_KY!J5b7K69HeP;W%SiqojB<1e~FwB}+HtS~5teVHoih2*qWt5&ZElc&7 z;0SUDxk7OS#aRNL2+Vy86kx6G*4G!?i<^r#7eDsG;`-u;k!vir1v`M_%v)+(8axf0 zbd$cn5-=NE0EQ_(4A^G-!NYtGNljQqMo*Z6?E>qrgK2FwZ8n#lSgipBXl_dLrk>sw z--phkLGTd*p6~-pITx^cuoQ@Bqm4B243?&jFjVrnX}x1?nB@37UegZjz=1q;-Z31b z>0)`Vti1$ET@_*I59zS2wwsO2yo65>kVB7T{w@ zaRtBtb(YPXS+uHF&D=fx@Nu3Q>|x?ud7eDi3rH~uD>b)3tCY-vbJyOpaIWTo1wwL9 z7>DL*Q1X^tr!32@-Tz-aS5~pu{IxLL4mv&Q?L_Oi$T0Yi|J?lMyI=e7e}AFjNc#r# z#tO#9rml-b?YUu#j{~C!4o3|U(t42E@*90|vn7Hb zmqp*#;sW`y?_b|(b#;w8H?fN0)hc%>3r0+gCW!8}z zzLyuDL}T#^0T(=CC&7`#rwB$x>!0SbJOExf4#nE74LdyoX^#nPLh!rSw=z{uGc` zl(&t{B<1mXQmYO=-E(NUpyUK*-3V}9@QNJJVou=|Hmh@>mo1$~EB9_5t@6563Y|Ud zsl5$oXo>}K1^`OBbS)H7*!5%B8A>qS=hE)E;wSNvxi}FWcp~}jRuI6Kkh!a(zq1`{ z5mwN2{}~QPE|PX9O0<}#z>&pm(ccQ&2|b%!aI+2XM3x4){!Y?xwd}dV#`TSj$PaIA zx1t!ir zDaz@lY6B5DlesMj1Lcxh86VjPPHKUJ?7Fc>p28*3piJo`J8Ob`pbys++!&<8I#F!^ z>9#j_qHEGzmCjXodLm$h2~BWoNm{p-$0%3PJQcwvQntgQxuE=x-TCnyJ`P6fb5hca zVeTFupX^asjfC;jNF!Fu3*vYb6OWyhViM(4L&;Oo6SL`L(UXy>RFH-9NkZ4KAGO`F zVVNJY+CQN%$!d^*hHod%fvcPX-m-aUE7RWt^Ax~(z<_z+E1_j^3a&H-S03VQj)+lc zGLt1|j;yI1Su;DV8LdQFVGVap5t(RWW8!TB%EWpp_taae(i?;tW>Wh{aThDkV}exnOh#_R+!UNJjTl+;`CL5Ji!!pje(CT zipj1y^SD)qLtVE{nfO874R3VFwrdj#oGazf7m(ur0-!1PwzFsNIeYG2PVp|Zowi^1 zaQZOf4DhXT?)E&zwZ=g%D;J>f8EASl%~SZIJwh~idr3ac;|`524^&|9tGHF#Dx()~-9nE6#>+4s=I{WUkqn85g 
zWy?$56hXtg?p_Hiucjr)n^IbWyz$m{7a1+U{Vui}qxgYF=r97}2M){1R25foSc0xH zJB~LS1Rh31!ECv~)(rxB{-Cb6nCImQDld@`TqoEe5CoK`ngSG=(dw7Qjh(1X4uw-( z82YucfGA=VBQk?f6r1H^9W2rmRUWgDon9-3++3Gtd$b;DUY91YH>A@_l1SRnDbpn9 z=P-5_xuFTopU0o2^6-pnh~t!jVN7RtZbW>KL!$l$7n4cU`e>?=JZ+e3M_Vt7u{i=g z0+@!g#p8mwheRst`NWh01Cbh3RKO)4cy|g2xXh}+Gli-OH;?(BRq4Hc`Mi=F(hEA> z7Vqx<&tjvj#s5Oa0=Sw#lgcoWiN$#{0hLN+=Yo_rMZ1dUk64t%o9+zxvx!| zQC%JZ8W=UHyARIw3(E7PUDzmUqhlX=ZERDDNS8#lq>e)T+z_#>)Qz-ZS~T*Snz%Zp zmo{1~D7$lm!#P4caqGN*3+zjs1~p*7u*ggbm=D$o8V|}O z2NqTzqfQzJ7h7%AQ5zAzfbQZK3BCX@9w8AwL+n#S@HBI)0Kbi-2*dzg*UjDPaJi3z zpns3EWduDxoN0`nbK(wH_IZMzCwK>7RlE!MZE2lX{F_J_tyl6zR}j}mns`W!d?w6? zz*Dr2kd>l!{qbd*`B3DVbyP^LaUv4NeGQF>6om(gc^|h5Lblc0&EB~r50!!YMXW%_ z6twMLNb(0o)+yt{Vk~V~(nh9(79~;x@MVUD?wW^W4Sa*r@rJjthRk+XCc5x}K>UH` z`V|BP^POnBi%&C+isq9qvOQ&hIy&jR(j{`a zVU6tcX)N+pG>xf}wZBT)x?t`qTF~b@WRnSBh&_f}LMqCXxKiOzoiM#29Lxjzkg-SQ z$|^&gu^AZQhg8rml(+|VNR^JrF7no10ePHo;H2JyfpQULm~$%X>mDfeifAwtH0p3i zIByU~>f)|)Y5verM!kvytK1C45DbaB!3a>Mc5~Xym)?3uf1gpIS$a@PypZFSO@>QQ z3bO-*EWtSn#daUUG!(y1OfCqMR(n%NLGuiE;){v+CH8)g;FsCRZnyAJCtC^$D^aJ=48%R){7pDn}KgM|R0st@X30z$f z60QnZQ|b`Dd~L)aXC8+Dd7mrD1jw%0z$uldSR}sWLBr0qm z`Ftl%eenW}UH_$SzkTf?8R20}jvapq;6WL&1;dkGd?pKY6di`04|$=DP~`K8&30%5 zN`#k#*t#+Z+0N7LSfd=2f^N(5Lt_@l#>AKG2c*X8Jx7a;pTato?9keeUh91q0Hk%6 zRbL@ca)Sn3<{Hr2UsB>h+UpU#10t50U)K2*=2z}%c_F>_W-CH)VXDLaj~F(t0BE`^ zTThr(q}}<^DNh$0TqV;5q<9q|QEHWz1en?cQ$t|F_NQ8Ps8&^(ICrc}A0a<1<7=3X zYkN9YOSK~?lEm*4{64`wg0B*g(z4iv%t7AnBum+QIeV{2XM^JODK=N&Y38`ve@4pe ziCSii?$*Zk{5Ux5UpVK0!^Y+saoEuoCkf#n*gprrZyjD<5yuD0TKS9mtLsC2#fXmP zD{~rOQO8mzC{A%TM{QA>PhxZA7r0WBSjBMxw8q2;UIJaAqQ2y$QpRbX;xs)|At>I( zh}pM78nQEw+_Lh>EoYM|R#v(T8_}N|%h$&8^JDpXwo(Kva2jP5v^X{ecYF%!1Z z6x>NL%_)9aRdQu!tz&?wbo7Bb55_6Y5HXZ3ofmOyN_5|Ukrk9OtKxTq$JoRKd15fY0p2rIr;l!}lJRL*k> zo>LIK3k)XN7Gaht;*hL3c&f;-ti<8-xDVVBv;y@60+#=$&WA#7vsl~tc|>P(rZj6U_Wj6O5S9ZS(Fok5=pO2^X*@UyD-Z0csS=<>|4 zJfrGM!*WTL3&V0jmGi?g!dqz}Eu}N*><0gFf$y0reEqASC-G|pWa}$}Mor^{_!=>P zLhz>qe?~wl3Y&uT zU%5XXJ{0zo-ZO%@#)l9k+39WH#(%}Oi2OOQHW+P-m;p5HQB2JXj^OwRNgA##5{qfu~WF-Y2x%`2vp6 oh?fCy4Df8q@M(15Pv;fR@{W17XV b (h w) c') + +def to_4d(x,h,w): + return rearrange(x, 'b (h w) c -> b c h w',h=h,w=w) + +class BiasFree_LayerNorm(nn.Module): + def __init__(self, normalized_shape): + super(BiasFree_LayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + normalized_shape = torch.Size(normalized_shape) + + assert len(normalized_shape) == 1 + + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.normalized_shape = normalized_shape + + def forward(self, x): + sigma = x.var(-1, keepdim=True, unbiased=False) + return x / torch.sqrt(sigma+1e-5) * self.weight + +class WithBias_LayerNorm(nn.Module): + def __init__(self, normalized_shape): + super(WithBias_LayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + normalized_shape = torch.Size(normalized_shape) + + assert len(normalized_shape) == 1 + + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.normalized_shape = normalized_shape + + def forward(self, x): + mu = x.mean(-1, keepdim=True) + sigma = x.var(-1, keepdim=True, unbiased=False) + return (x - mu) / torch.sqrt(sigma+1e-5) * self.weight + self.bias + + def initialize(self): + weight_init(self) + +class LayerNorm(nn.Module): + def __init__(self, dim, LayerNorm_type): + super(LayerNorm, self).__init__() + if LayerNorm_type =='BiasFree': + self.body = 
BiasFree_LayerNorm(dim) + else: + self.body = WithBias_LayerNorm(dim) + + def forward(self, x): + h, w = x.shape[-2:] + return to_4d(self.body(to_3d(x)), h, w) + + def initialize(self): + weight_init(self) + +class FeedForward(nn.Module): + def __init__(self, dim, ffn_expansion_factor, bias): + super(FeedForward, self).__init__() + hidden_features = int(dim*ffn_expansion_factor) + self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias) + self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias) + self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias) + + def forward(self, x): + x = self.project_in(x) + x1, x2 = self.dwconv(x).chunk(2, dim=1) + x = F.gelu(x1) * x2 + x = self.project_out(x) + return x + + def initialize(self): + weight_init(self) + +class Attention(nn.Module): + def __init__(self, dim, num_heads, bias,mode): + super(Attention, self).__init__() + self.num_heads = num_heads + self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) + + self.qkv_0 = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) + self.qkv_1 = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) + self.qkv_2 = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) + + self.qkv1conv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim, bias=bias) + self.qkv2conv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim,bias=bias) + self.qkv3conv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim,bias=bias) + + self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) + + def forward(self, x,mask=None): + b,c,h,w = x.shape + q=self.qkv1conv(self.qkv_0(x)) + k=self.qkv2conv(self.qkv_1(x)) + v=self.qkv3conv(self.qkv_2(x)) + if mask is not None: + q=q*mask + k=k*mask + + q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + + q = torch.nn.functional.normalize(q, dim=-1) + k = torch.nn.functional.normalize(k, dim=-1) + attn = (q @ k.transpose(-2, -1)) * self.temperature + attn = attn.softmax(dim=-1) + out = (attn @ v) + out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) + out = self.project_out(out) + return out + + def initialize(self): + weight_init(self) + +class MSA_head(nn.Module): + def __init__(self, mode='dilation',dim=128, num_heads=8, ffn_expansion_factor=4, bias=False, LayerNorm_type='WithBias'): + super(MSA_head, self).__init__() + self.norm1 = LayerNorm(dim, LayerNorm_type) + self.attn = Attention(dim, num_heads, bias,mode) + self.norm2 = LayerNorm(dim, LayerNorm_type) + self.ffn = FeedForward(dim, ffn_expansion_factor, bias) + + def forward(self, x,mask=None): + x = x + self.attn(self.norm1(x),mask) + x = x + self.ffn(self.norm2(x)) + return x + + def initialize(self): + weight_init(self) + +class MSA_module(nn.Module): + def __init__(self, dim=128): + super(MSA_module, self).__init__() + self.B_TA = MSA_head() + self.F_TA = MSA_head() + self.TA = MSA_head() + self.Fuse = nn.Conv2d(3*dim,dim,kernel_size=3,padding=1) + self.Fuse2 = nn.Sequential(nn.Conv2d(dim, dim, kernel_size=1), nn.Conv2d(dim, dim, kernel_size=3, padding=1), nn.BatchNorm2d(dim), nn.ReLU(inplace=True)) + + def forward(self,x,side_x,mask): + N,C,H,W = x.shape + mask = F.interpolate(mask,size=x.size()[2:],mode='bilinear') + mask_d = mask.detach() + mask_d = 
torch.sigmoid(mask_d) + xf = self.F_TA(x,mask_d) + xb = self.B_TA(x,1-mask_d) + x = self.TA(x) + x = torch.cat((xb,xf,x),1) + x = x.view(N,3*C,H,W) + x = self.Fuse(x) + D = self.Fuse2(side_x+side_x*x) + return D + + def initialize(self): + weight_init(self) + +class Conv_Block(nn.Module): + def __init__(self, channels): + super(Conv_Block, self).__init__() + self.conv1 = nn.Conv2d(channels*3, channels, kernel_size=3, stride=1, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(channels) + + self.conv2 = nn.Conv2d(channels, channels*2, kernel_size=5, stride=1, padding=2, bias=False) + self.bn2 = nn.BatchNorm2d(channels*2) + + self.conv3 = nn.Conv2d(channels*2, channels, kernel_size=3, stride=1, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(channels) + + def forward(self, input1, input2, input3): + fuse = torch.cat((input1, input2, input3), 1) + fuse = self.bn1(self.conv1(fuse)) + fuse = self.bn2(self.conv2(fuse)) + fuse = self.bn3(self.conv3(fuse)) + return fuse + + def initialize(self): + weight_init(self) + + +class Decoder(nn.Module): + def __init__(self, channels): + super(Decoder, self).__init__() + + self.side_conv1 = nn.Conv2d(512, channels, kernel_size=3, stride=1, padding=1) + self.side_conv2 = nn.Conv2d(320, channels, kernel_size=3, stride=1, padding=1) + self.side_conv3 = nn.Conv2d(128, channels, kernel_size=3, stride=1, padding=1) + self.side_conv4 = nn.Conv2d(64, channels, kernel_size=3, stride=1, padding=1) + + self.conv_block = Conv_Block(channels) + + self.fuse1 = nn.Sequential(nn.Conv2d(channels*2, channels, kernel_size=3, stride=1, padding=1, bias=False),nn.BatchNorm2d(channels)) + self.fuse2 = nn.Sequential(nn.Conv2d(channels*2, channels, kernel_size=3, stride=1, padding=1, bias=False),nn.BatchNorm2d(channels)) + self.fuse3 = nn.Sequential(nn.Conv2d(channels*2, channels, kernel_size=3, stride=1, padding=1, bias=False),nn.BatchNorm2d(channels)) + + self.MSA5=MSA_module(dim = channels) + self.MSA4=MSA_module(dim = channels) + self.MSA3=MSA_module(dim = channels) + self.MSA2=MSA_module(dim = channels) + + self.predtrans1 = nn.Conv2d(channels, 1, kernel_size=3, padding=1) + self.predtrans2 = nn.Conv2d(channels, 1, kernel_size=3, padding=1) + self.predtrans3 = nn.Conv2d(channels, 1, kernel_size=3, padding=1) + self.predtrans4 = nn.Conv2d(channels, 1, kernel_size=3, padding=1) + self.predtrans5 = nn.Conv2d(channels, 1, kernel_size=3, padding=1) + + self.initialize() + + + + def forward(self, E4, E3, E2, E1,shape): + E4, E3, E2, E1= self.side_conv1(E4), self.side_conv2(E3), self.side_conv3(E2), self.side_conv4(E1) + + if E4.size()[2:] != E3.size()[2:]: + E4 = F.interpolate(E4, size=E3.size()[2:], mode='bilinear') + if E2.size()[2:] != E3.size()[2:]: + E2 = F.interpolate(E2, size=E3.size()[2:], mode='bilinear') + + E5 = self.conv_block(E4, E3, E2) + + E4 = torch.cat((E4, E5),1) + E3 = torch.cat((E3, E5),1) + E2 = torch.cat((E2, E5),1) + + E4 = F.relu(self.fuse1(E4), inplace=True) + E3 = F.relu(self.fuse2(E3), inplace=True) + E2 = F.relu(self.fuse3(E2), inplace=True) + + P5 = self.predtrans5(E5) + + D4 = self.MSA5(E5, E4, P5) + D4 = F.interpolate(D4, size=E3.size()[2:], mode='bilinear') + P4 = self.predtrans4(D4) + + D3 = self.MSA4(D4, E3, P4) + D3 = F.interpolate(D3, size=E2.size()[2:], mode='bilinear') + P3 = self.predtrans3(D3) + + D2 = self.MSA3(D3, E2, P3) + D2 = F.interpolate(D2, size=E1.size()[2:], mode='bilinear') + P2 = self.predtrans2(D2) + + D1 = self.MSA2(D2, E1, P2) + P1 =self.predtrans1(D1) + + P1 = F.interpolate(P1, size=shape, mode='bilinear') + P2 = 
F.interpolate(P2, size=shape, mode='bilinear') + P3 = F.interpolate(P3, size=shape, mode='bilinear') + P4 = F.interpolate(P4, size=shape, mode='bilinear') + P5 = F.interpolate(P5, size=shape, mode='bilinear') + + return P5, P4, P3, P2, P1 + + def initialize(self): + weight_init(self) + + diff --git a/model/encoder/__pycache__/pvtv2_encoder.cpython-38.pyc b/model/encoder/__pycache__/pvtv2_encoder.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..814d56583fe8884f05400f565b96a1d0477fd2b5 GIT binary patch literal 15744 zcmeHOYm6k!>zd|k zzTVI($y09{6+=(LX45blr zX%w0hm5D9sdSkLO`KadGe)f6I&j!ZAz7?}lM2h3*kdjMN_9Mmh$B;6Xrc5Cv?~fy8 zJWaU?DFuH5DHCbR0i;a&`;f9PO*!bl+~4m{J+D^|1>?xQ$v=SH1Hpk6b49Nlmb8QZ zA*3Bj(~d~mVgCryj-+WfOWMu;QKTJB(~e5o%lu%sOww-gOGqoFX~%s1ZCd%( zb&mf`S&y@ws*38>M%j#=bE@5Wq8iP|xv1@(iWWPKARdp@Vyou0+NxP?czB_pg1LGa z1t5g9$Rt9OkxqpNvVl<68PIzo2WTG8;#O1{xGCEr0T*UhZ4gxl46V_j__zowIW z-lkpCXUgMoCS2?UDt0`t-l|8Q7aOfsY(HLa1yvQBv$d14eeS7;+O12m?LS-N)2!Aa z=A4SN9Ow386k9$5ai-eoG%m;ENlniN^||>dEZeaa294QxtlsiwgKD&>f-s(#uls&5 zl8{F=!|WWI@EX<2fl@bNG_L7!&hu)GY8ZOnHSPOvJwD%V2FDlAFSeq^CR=9M>|Ch zKZ&Nhx`SUqo&`N`EafqK)BQoJ87!1uoU|+)7eeZ$wk(hS&S}}I(VZ9Yb)mAM~<0zsEBt*#APIx1>&eyA9 zY?1Ne9C(YRBQMT9&Y7RVB*r#5@01^pT?kYwXn0|LDTswvYV*|=Qo`7tQ|-kLs`0B) zHO{8;9a~MxzZG6oQJh%{RNK2|szYpDZ^qf#YRd!1sax15Ati2|Uz35^}1@huV+OH7otR=VJYhvHsMiZ^JjY)w5OQzY1kT&IOsA z44(X!#`sNX8K4yap;c`L00h?~KrS{I&U@a)#cCtTk!g*sM<0It?IJ!heCqY685od= zYb|IrJa42cFGH>>kvJgyjF8~x=!f#ogrwCF-Ej4RGMq;80Dtp>x>R?$ikT?6ij}Cg zN=8uE2I`K-xi>^n(2DBq);e2!WJcxiP`829n9#qN%9c#rVGJ??!I-ZDQ05nm@Se!< zF=L<6d;l`2_1Vmd{yxsA<{b$Az_?i8LjXd<#Im98K{&R+{@r1{60J45moDZ-ET&{HM6Is^xP5>0@x1VZ7t)v$G& zTEZv-CI}XJq&WH%^+j>OcR$8jL%kd$rCvdDhGZMK<-qR-w+BbYa!1gmj-fyc=sp2$ z1Qn(|Z7y5OnPoeH-8H>ubgga%!eDlhwqPF7)~ueWXBu#8)-nVOD$s;j{b5({=d79X z9uy^D+k$9oJ-h3qwb_0K`Dq;hb>?Q7DRc(c^=K}Q@Fhpc$D~!@c2RTf9Qf$Url~)+96Y#}x{0tV(8cBl^-#^<_5ImyYPm9HuxQ zXC7>{YZp?ypFG?DtlQ8+jUUjzC~j<@ED*j6WJF)OU0oI>2CZKR4Mi>sfgAJntJGW; z^r8U~yAR73XGL$>)O2gqlTNq)ZZ1NoYffn^VAg01;e0pOwS80c8yR^V3|<)MG;5}8 z3&yLrqBH7kB*HNRWEaI>w$x!}(ZEwjNw~l$PmE}-43)@98FA&a1V1scb)3hrXo>34 z)T%C!)JcdQ zBhXS}lsNYZ+=40^@a>qfG@f+5k7Vv<#G;U^2S{jM?HRxV6kfy~PJ%e#Ljy}9OEZ30 zm;{tDcMFtdkjN`M1u}#(nD*lWTy&@8VBqI`7jBh|Kj!D@W)a}Z0=NqCzPDToRHI72 zPi_9;=J~*1I#znDS)B_?QM)9mB}w+{tvR=ZYyCTDYds61o+5cWqXe`qz8TJ*09_A( zA<#vbJHV5UC%iCRSOdZa%K*{N(C~+~BbagDfUd^kfiXjt1Sm@zgw}|SB>;Ul&O%PW z65N2Yc+-c4lDx&xtl?W%QQmVS2fm!dud@{jGpU}w4V@KbT{gu=R}UkBelnj!Fp7DK zEE+rY=3KJEb^&a)c?pl`)-hV{k`si55&cL z80P}c(oUxK8r}%GU{9SfVr<_|BU?Z4L&t!LePqjw{{+l9%sYYEmpG!kfY~kEjOhko z8})oqXvsMbv0|Dy0&9RXg4lxow6$#D2NN8oG}sx2EsRY(4e@;Ry&TZ;U2E2a5SueY z+Y+7*VXU7+weVduV7dh@^)zB)Dp)AG7A&HM5fkoZ@7wSEc_D$ zhM1U|cW^Y*$vh{RNIXPnEcqccP7gkR;-of_KIxlC&m-jexjAcgLU47TUlg}Befe}- z-zlDay^{6CdtY&U@z_^#fx9rakoUz`4-fyiFCO|z!7uwI|K->RnDB4)Z$llE{ww@f z;q#2=7>E)5Vq+;?G@Cuc@v8mvHX>PF?Ho<*v9 zj)ca*5RGJsr@!mZi}??<3t&z8;>P$!+m z;aL#r?7c8w$dnpbGba46_O8-{4&W0NyH;iVu)S+7u-MUuAXCOcJJ!H{~ze8NbT-j{g2(dT0S!35)jK* z_8+l#Y5PGZnh!;1$N9uk7Op(GA$Dp*?DU4%nK%=*Rc$_Y8rUQ2RBPb&aIu*<8QeF^ zZp-5k;y42pITytCc{*iIDlYTV$SGOrrIFK#6XJ9lMfLO?oEbVHup5U>HtddL|ExT@ z*)A8<`xs~WBP1e5zs3;7C=sX+GjtUsE(|;$BBqJteVo}JB@s84Fmj2}JjkEp4haI9 z<-%cZo1n1Zi<_x+RJZh7E%yg*reNf+7s}-IrC06=6@UD8YWLuCHv!wbEd=hm{_yU& zCZM6k`yt%2m!9lSV{f|PV~NbY{+aP_fm%fLRBk!dK*p~O)!F%79naR#na_0HS_*OnsWIe}?2YNCcBU$Ix$*%!BmZwyJ$z z^_aLu@_CT{dh!cQ_*Ie*k^BnD#t(|&!P?NngF|UL1^_GShm67|x^&NteDa1VK;>NY zB)my9_2Fb%&^(+%5=nXlRv6uS@wf^uF4oaEZ%$Qx7+932&3&JoRY9;6Bwid3TC2>? 
zvwlrqdewD1x%tg4y`&V01kuQ(1j#N$WP(;=S;NPUUC4#^MhLT@;!sEF?Jd9O)qZw7!JF9jEw*Xyzg~|9nTG7 zwOw}#bKqPt6Y8_|K%E%$Mqj%7x?{WXy;!T6XtjSaOd*u+(9b4rM!FPr?4cZuxLwGN zMfg>)lcM)5*-3#FpCB#21Y0Ab-nkM^&f&-!XO7V?Ge1FJDHM;*HQMKKM6T0rH#QIW zWWxKyZ#>O!Oz_vAZ`Td9k>AqB_>OIS4{Z##U~ixeoCn!7R+G}kfvscpVg2HhXhlwo z$ij86qhFJ1yuxd2||fNQw~mTA(e3dzrd#4e73 zMfJJG_96^zoF@!gzWR0M_DE!NxL)OGP~hvavW>W+6HMLB2(@~ z8i=QuAxK2RYXc}!UqcS?fB3~4Rh8`d%nv2hHJ4s-T^veBU}rR>q!U==Fxrl|tbPZ@ z+H5z~EO%VD)sdwMj-=}9EBvadO`g_d!-4QO;%fapMt`5=4@hVL#O4A{6WaATtbhXb zM@;@>lCP3*yktErLz3f2IA(U8vy`qD&lXn5eAs`W+$QC6=xqJE?ztBa1CqoZG%DIW2erJF0Ku`7w?9JRO1IkKvW;zah z*dIikDzlQQ5T~Qa5nIMZJ8Q67c@U90&-4Nx*6^Q%XHEd&vEu?O(Y#=*GhGXM(^`;? zNAf~fpS3WC7W5(3EB9#^^IZdDXDz*{#KYYU;6y5$#2N1fG|GE9d>m@cuyAGg^f{cLaMO7AY!-&<}dbKWS@q6 zF4(=snJr&O$c?as#K9bR@vj=#$C00u&=Ix~k{z6@ZgTvZzU!yZ!g>iry^En9L+>FO zaQ!-2LtNEQe)sCXeYy9aU;onma6FqqCzQi4-~IHbczyK&xNJGT+q!$illpPyia+PW z41EElj}L#!1Yu`>{}7h1CHN5A7jVjWF5C_xKE&yv>`|c_fK>eO%%@M|g{9(PqzA}x zEwimGon_-gAsu2j`}aXM#O^tUV*JIK{gXfH*_Vk|@5TqMpT!rd_aZc$ef14S9|U<) z{RM-6NkUDb{)*(ULHd*Uw@lb(64f`Eu?@O^a_>`jaxiu3Z&=%Qlf?EvF)~Svzc`b0 z>Lr?_aR_d~AKxIN6Z})uKtzj)hz=)7eT$#Jmq{Wds`rs>KV5&vtQ(oG-rlF{)E=gb z?SF1$x)^_Prt9?o*Xdd>?mS(8f1}g&yL+Fm(|edMw*Q5Z>0vSzTIpK9 zdFSc+_Ki;0m-aqgXZA2%Y@Z&FzQ$tw#hI>G|1YPDhWK{&N^-hnBNze-{o()&_rC~W zxPIr(a~9v|oPBlgbN1>z%o*GN%E+8Cj-@M&_3|yrLe!BhzUfvg+1A_Nk513xf2jOZ zOtlfaBVqlKI44o~-9km{rBaOQ&E~ZHb1|Hj|3ik!X=JXZ{&v4A*?Zx|xcNXt1zFWJeSa8P-y_mzD J9l0Gt{~xD};NbuO literal 0 HcmV?d00001 diff --git a/model/encoder/__pycache__/swin_encoder.cpython-38.pyc b/model/encoder/__pycache__/swin_encoder.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a3d59a050add2eeb11e107d4c29b7e4408a700ed GIT binary patch literal 20350 zcmeHv+jAUOdS7=>PtO$t5CjPjqNrZ3EOEFLNKvxZD@(hg#MMfU1S?CeX-}+-hSLpT zfWZuSxAzqa1+G}TTa)4B2y>#wXsU#bx;v|(kBys!)bY7gw!qKW`hxf9B2T zVdLfnoWbuQ@C?uF8V!9lH?4+c>UeI`whYtQ%s29qVs)L(LZh%*Y!o+3jgo16$?$TX zeaG-@-&!x<&NV7t-gE94jjC7hinvZ8N69NAM>)<>LrTS~BBdIqOe1B=t0AQpr_3N_ z+M7YjOq_BADM!3nq|C-Cvq(AW%^_tjPC4qCSB(0x`)HA|T+cUYEfwWf+D}HgRy$@A|F4Z?1HDt?v2FzUOz(`@J@<>U?k$)roJG zw%$=Cd}fwCSVLf#WwVH1P0m3ToxhZ{ylpWSjxM58)=@o2+&6 z8&p{5I7hw6f$Oz4eGFW=*+frlcNwoXo7cBnUHwFHW=HmyzWCZ#q&L(_-cQB--uAoQ zW^<$t$MICLA}Wu7bqq*_MRTDfDP#S__7!nTKYbaIp|NiKvb|&N`_h>i<5&BW8c}!?W1qK0&<$5RC~u6I+!HI!MDmTj!SAacvfP(yY;Ml z=EZaFE9cy+=iHa-X*!1A>)&h!oq_M3>GZ<-v#w5ZnN*g1U-m()y^+FnjsQjWwl|yc z-B}brQrg5#{7Ge^(f1`Fz1%0^+-HB?XYbwrA&$C@S+>^N@+133$G;gBFyS`0Lq96F z`@OKUy4~Ln>IKE&jw&M`jLa7!^OeZF8ksLeg}72910}>V6Th(4Qeh|T^m~t@n1DlK zlxpTl^HEb>M7+%^B?lQxI%wB$6LQ$VvrZU84AvZm>LEEFGH#nYHU|nr2eQT>CPRsV z!yx8|`Mn$lF&;v%fI+k|h&q1`c}N6#8N~m1z#ztr`Z4;QBrwuf$xvUF(bq%GEIeLG z9=_stR@Xwezv6Z_TdRIbGFOufS36#~mOVe1Rt_ROp#t?dx>B8H@M#7d3H4J99z#&i zM>&wAI>UrIgR=}i!$*#KlJRp4mKdC8@B{-=#pIaR;)W~#hOYwuDP++jUOA;l{Ar}L zagH4c1q3f5Ig?b4oL3VWMLAq**pQ=?#AYM!J6`EKrdReVckD*Nt9nzA&qa}4C5-4y zG_8C4QW*N6<9@I8Jy5?J55YAsXwTi;?uMNw*8G;|QdYRFB(Hns>JzVDxm0&G&%#)u z=fT%oVdsYLg6Rt}yVp9cU`fW|)zF2g+Ulz?aIf|CFa>L!l@MQP^*p!N?>&*+>;75N zlYN|Rc%98;w3oLxulY)bySKF+y6v@AujhB;?~!uie%8E1)PX;TA$$ug<;^wUUDM>m zkKarPCq#Pkb(X+`EQLJ~Jow?PJ0yc{Dc`xiaiht$yJxQT``vSHe@i;Ji%+|+uDEZi zZU3Cx^3Ze_qEPi(*SfxRmK*ln>)XD1=bXFYBe>D(Zu?8_Wq+lGp+Lp)W=i3ed5B>PrIEJ40D*4 zkLLB7)D_|njFlWK6RyIL_c;n#XWW#Y9jWz&r3-1HTdKeQ$qR+{fw4WHs5h4t`CwaE zWjs-)apgTZdGPyS>LVH#^(G5U&}ins1gT}2!`z~=Qyk{EmWD;oJPIZXoy*xPgvGs* zhrC|ygfRzIcfpW5JGldcg-#}#SQ&XMV7@u#h5k@V?~xn#d|t&lH7pIwt9S=#72MZ` z70=$ccB;dw2i5ysD+yOn&S)B+F>S%vnHo;*O+(hAUi-PdnSDt~-jgzNE4jRZ_v8Ao z=AT7pJZB82#`4)C`B2Y8^Tp+%^(--M|L5W|_l|hbM|WxvFteUBtnI^$2%Yt#l>c4< 
zy<~gE_t@8S=+$FGM|$iy;*vwO+Ie|WkS7b_34Eta^Wi;R(jxSV%6lBShp6$^tIMl; zr2pV|HFu`m)*PF>JDN`zP%~jb(270(mhf}vsy3(O5B|d3-~G~m8hrf=NXP}DSH}f& zpng)3bhL)`9AjkXQ7N7MPy#sDv^G|Hqots>g;2foVPCb^qWr*D{UAE|U={e$X2=;H zIe4qL>Z7Qz)!vATn|`piis=cR5u6ijI1?(U+xBKOt^CzafX_6qZLh3gDpy#k9N3E@ zu}7tJibggkR#Zx-m!iOk3Rn6o;b!YrR0wpWJ}r8yN}H&jXTD;>u^7`_Mg1HesMq9a z!YLZXcE7K@pt+XDZbptyYt-Uraa$Uwyp{(|{XkI>SE8S$+OqUF70-Y|&#=m-BR}Z) zIoo{fhmJjC&X{v%)jVg;nrHAkXO_*XbsQ-*Tx;fFK})PliD=3!0`h_6p!UuIehFvL z1>p=$Xu7!(Jr{Emicbzy3yNFM4Wa$6gG21+_himNO@f}gpQl#K*^Ur10-<#R^T9H9 z3OmJ}(y)NH@^iRDthncR4wNcu#@Hz%ozKWyK<_7vy%O%qQWCE+yB;| z!QV;+z_>uCw1i655fPr4p<0~~rJIsvQmbDcvlv*yW-Hi$7JA|Q6P?})loH&2L>aI- z*3MqN8DHvdUy0yMYdE6+WgLc@;QbF_4~U+)TtA{!!B>$S6<|8kLXf$n=kVv5UJ5ba z127|BQESr^K$J((6{$Q%rR`n-iSL7d5G2sE?$0pBsaK!U(^aeso9&9M7n_l}tO!^r z`MP;MGB+agMr4y4G>)g!F6Dw`K|ltsPH3WHNrosVpw?AV&bSvYxh`X{mn4g?>u~7@Q>L@YL+U&gKV4MP5<{vFsnPtYeDO+K@;0j5x zZhhN$pEjet{DBRo{QL*CF|+OY531)feeC&qZi0>7Gj(rGSmb5@76yV~g!*-SIe?lA zqkOP9*@7(Fn2}62LS`ye-$KL#WJ7xj6{dnK3=zUqsg|>0W|Q{FCmEXPy)At$su0Bqi^8UNQ7lu3)U(8=FGvd z$>z)ZbvxFqM7^p5WGbtN`HfQy9zhTlnoY0Y2Ix>?633E)R4Jdg{GR$0gKQ8oNaCmQ z;4;pDm`2gGr)sCmpQ_~))7v=54(8zz{33!eizfvgVC&%L+v|Cp&PrJ&YeDT z<=IbenFzZBK?eQqHZ249%$Rxds&GR*-tPAKI5R9#b>v4dM$m>yNJS65*x6axVt$Pw zFIQ-wNiG*%<6aS)FMDCLyVca{zcxT%-F!K*2Jp<)SfWb{17?j6bItQQy;UY;g(Ql{ z0jARb^OjQ9icerAjhb13(Ryht#|#s~0xP4XFm>d|(A{dn8_6>~T>G+9Ek99D*A zNIfi%`y6#cYEfS9J!oFdDO!UYw)*W5cHwQa=Y-R=@9^YETEc8dtI3%oq*n{)5U)Y6 zD|z|%So32j+40N`8#>^BLp=A(*3R_KOn7{69+uhkaK>|1p!Mw>2^XZ+`HFGd3Kv+P z3Wb-G-XrGW*-CC__D9HzrI?+$;p}DOt&2OyhDULIZ0GoJZg@o7ox@|p<2trt!y29Q z9EiKQubS7pI6c^XM_7~F*3SHJes58{Gw8QSfreCPl`-cB!2Z411de*crCt)6&|3? z??k1PW9XMIDuR)+On{=hTj~YgUP1tqwcBa?0!C~2^B0&}VM}}uD>>b+C?Ks+y@Chz z6C+mUCJT{HlA94}2ldSobTquAM8ZL$`P;#yPc z7&MAlKVWTxDA(?Vk+~L`Hyc%EO4bHcm!CPn6#YKZgFlbJ$R(zzHb`eI2PWrf=)#t1 z!IESszYmKxHR7r{Ye7GzmJEZn`k|dmuC^>$EkLWT!jPRaAA=G$c=#ZlBePsZLuD;+ z(Bw}o*9*|2F*6LzZ>-V4hBaZs0;Gpk3fqq0|ArNu4Y4vK{KteHYH)2)TVi{#0htR} zj7@88J?~lIJ6k|)xN4Jd1Mq|uQL*03?G$@J%65w2d*p{D%+b=Y1Z$M@lNM|V%B=(2 zPpny>T*i9EbKbLHMV3c%$rsDufV6{4tix>^*jrMA{~|Szd#0sdH}Ul%ZOEkN@z+e+ z*|1X4u6cvc<&GAru_qQNkv9DVDqdmvgOg(hrVByn-c)_)A_Rv;mO!tTN&6_eyKY5} zOt=7RrJc<{W^&zBSNUv42|=GhTwQ0N7>HFTr2PVJBU|;mT|HZb=w%8D(GllRLghCY z+eT2YWJX>_+OmKUnXieZ*C+;Sz|Py-VWT89!=$^_DBVg3x=|H1LUXSNi~#EBe}pqQ zi-60KphU~Ybp}*^$|9{E&&}lK2212$e2i7(+4Jm)YQ z&x6cb&%x#+71_8#;;!dm`%xh;EUOj@_~7U{jS7sY$))z@U25RkH=y%1BI%Bm@h{$y%O{Akpmk)Vk07NdT_NcHE(xu1A)j1l}=Q|vA_@V z#T13`LtIwsVDe(EPC~5IafuZj17b%CPY6J5IfBFqLhJYOL|czKgqV~P`MvRP5PC^$ zN1DJ7C@ebc6FGcNRnrO08fKAj;_A9xlHf zyd+%l8I*x(YJ5WZZF$ylAY&VrdJ(OUQ<$^zDw za=2RWVvQ=kO6m#_XJ`jWp5mc-ifNwY{UmQzPASckJWH-wo^ojK<;C*hTA1yXprgli z=B%M@6Jy9uiXo$ZDTZ7boJx&KAjWhaIZNM>G*SIU4jf5A1p1(OOG6iDBJ=EEy6(1= zk7ef2#j<;EX*pXr5)uO4Oo(-n=r5o&XpvBB@3P##jNl`rwfYriy1;j=TNiLXiSr`P zrzY7a*%N6&_>^IbL^hYK8S7EFeDF61eG)jUHJBTtu#c3zoMA#`bBIjJ-uG=027@D$ zeC7FpbB>VA^Kfa(CV__{CM!+;`#LKv>M0=Lv8o*+8Cj{cVwZ6flRq(3Ev-bFL79b` z>d{99$}2I`7*~ai^HOOCIfr~&sy)gY{uG11f*_g3nbe|$nrs2n-e3!;P;(mB0BjFt z$J!xM>R$D5DRqYhml$vqKe1?f7nvX9x5+{vnu@c9;ykcVravHvSUiv5O$4KYD0R4k zHzI`+#E`R662U|fM8PEfmm-)d-!aASgMz8%O+#E|f@vlyzt{>oZ4p4N^B5xck_)d7 z;6oJNW2k4B_=`Y*@Go$&c)98yD6&9B58{$qG(G5@F-eCL^6k|G1C1fl5uqmVRqPy= zd4XIbY&+up@h1e8zShUGZ|oE@0hp!=XN&xwFa;x}Yd27(YLNC)hL?D31UWWE3-XbOXv}}}AZ#cDU@0z<zrki^P@{Ch zM|QG$pKpAVfoS}{!I;1Sl)xM85vEf^(>S1$J6&@Ad%0SKel|FIKyN&-bezZQqLsaX z2$dh8<1UmRs6FZwy!vG+nY!?G2cQ6OJBa}X5c+;@6u-+n(!YbUq}uR8PG_z(CRq~} zkAd8^?#84JM5W~}8&*R^n=*QdAB+04BMtgURWmhcSlXnjIkA*4aE`i-jDsRHO-k={ z={_4Z56>s7>LcxItZJ-q9vi9NA?hZ14s;pa<3EE)0>q%#XSB+>hdmesOs*4)w4|4x 
z7WC6fo`}zirPm42pVHZ)PLf)Z4wi(!j)cit^TF3yYhqynjX=$$syUb$slTY1>aUK! zFEl9R_f(ps(={c7JaFX+mvawt(-c*S1X?4E1o$AXJtH)ug^pVS> zbm=S*scaU*(s~*91RVe~$Wmlwn3pN_t5}f0LL72cW04G3S6FU#?8|Z*6$Q>-+3*#- zwq@V}x`%XDhSL-6Gi_tBRgtXnkbI(YMkmp%#Qk6dA&x3k`{Q)H786HQN{3hr80V#a z@5V(>iK7)wd+gv480c^E!{23akHK#<5Hb2&jLG~=g{M^Td&wjvE9K@SH`tLo^6O03663 z0z(%}U2;ka-@@<;N`!S%phWR{5LRd4hq#9oAbKbYSPjhy8r2=-NDvqT$il^@HByA-wM}lesa#@49COR((k{B6RQ>WK`0@vTh=M|1#6Q!azpB z8CeN=f~gl6WJbO8eKr7Oj=#eJm_fvVtJ*2^44l>$Fcg$#Bb^>gvoUp$aw3%#R0k=Y zfy_-n6OxLeLR3OtLTJ)( zjY2l8SNaX>YG1oiO*D1|HT*YhEa`;Qb_OTF_L;NFn-Z!AA?bK?U77g8L&Yvm&!*y`zF}%z4K&@V*Z2;~n?r@4zT%RQ#gc zFL)>9zUnP{4x_5W`!udc1gf6}p+6c;Pb^!s{x{_4 z18_)cSzU79c;`)ee`7<^R+q-uv+moI;K{S@D;;08kgG!=;TNFI!)s>Lr)(Sg{i=H< z-VCE%vx2u>_lfjFuDce7TfwvE&$raA&W)wMT0Otzt(?F3}nbr%T%`_j3F*C^o7enyu(1NwDVHPitdji zFJ4Lqjinyst(O)dpT+cAn<(Praj7TMG7i)+(LKutIn*79MrHacyw_7`Xng3diIeKc zroL>Vz!)I>BwQl%?;PoAfs4}mNn(;1Z-h%p!$fi$F{_MVt z`nqV<8D z(9P-ENC(8~`V?ICc&))z4~~OP0C4LLoDbo27#_h^E_*n)i={Y$*^a{1&{#jp4c@XT zYkK_{wx<5dReaszGd)>Vq&MhhAVrs-OZR&Iyc&7lznN$fm`BpsCegQ~TEa8nR=Xf~bI* z9{bmj`p49*sE#_Y?g!^4RFTwbI1p{iYqk2jcze=rq2j4N$KY86*zIO+jVZwvQu);@ zI0G&mN@J?BG;8b|x)B0#;9a{D`kTP^zK(nc?QK?UNC8Er4bdAH9z}A!HiCwJg9ZF} zKwbTGW*usCFV+L1T-WcxZ4fGh0OEo4GX4MqTlZtLksL;S8{dwq%$Dpo((Z_ROs{Ar zB|J`1ebp*ERc`eK1}`wUgdm!2#jKIr4V(H6_N3mLS7ai&d{hL65Nx&JJ^3X*{TAPK z;u;dX(S-4?B)7JaB zG*BDN|A21^szqDl2c+``B8F3jiJSndi>8xHohso#SIcrpF%K{Z&ID7A@Q_BXks_54Ux&P-CzVK0qQ6HG;;fK#Z63o$^R ztyRGY_CP)2#{z6Xoz__MJe>eD-&_bOAk>e+2`~pQ89`w9e?TIN0MG$X;Gc5n)q#tn zc|2rmT#}KjZB`XYjiSq6)Safr$wJOML!hy+D}f7XHr1 z_o8CT9sf07lu^uvawY~;QXcvgJCM(bc-!z@;nc={LC{wb*W>wtSVbd-Z}K-dX%TH` zSLs2zO1tv~+;A!idSI{S#ajgOoj#L$guLk3jLsb*ZQtIrxb2px0^;~nn)p9V%FCj< z-4{sd@1sQZ_ZU3RfaI*VeV%6QQ3UWPsWbk424@+3hN+_n)$btzp7H2PI12O9D8@ZI z`L96g-y`cQ>>{EkK&sFwkC}sqCMj)XV8)0GZ)Q*xdl@8>5EWJ|@iX7WwqHF7*&=u= z|2{5ik-_-b3D~pDy4ZLKASpMPpRC8n!eu?YDN8FL;&u63D=w)3X9he*o%F4_#KjH7 znk=}S6w>EN!Nrddn#3?>@bMv(<{SkM&2rLw_!|LzkCu%h$9jZ{n$A~H7a+Wl=m1(^ zGzg)~u>BUEsDH$Ok}>fSVt&~Pst%xRf!^L^?`S4zW2~?V7GT>9-2EQtC*o~~vnF;h z*}VLb)S-M){VQa9i=V`a0eQfEc4ppJoQ`FO}7F@wSpBGquMw)@WpUWc;o89KM5pC`_T||N{}(7(2ju_& literal 0 HcmV?d00001 diff --git a/model/encoder/pvtv2_encoder.py b/model/encoder/pvtv2_encoder.py new file mode 100644 index 0000000..f30604f --- /dev/null +++ b/model/encoder/pvtv2_encoder.py @@ -0,0 +1,445 @@ +#!/usr/bin/python3 +#coding=utf-8 + +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial + +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from timm.models.registry import register_model +from timm.models.vision_transformer import _cfg +from timm.models.registry import register_model + +import math + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.dwconv = DWConv(hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * 
m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W): + x = self.fc1(x) + x = self.dwconv(x, H, W) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.sr_ratio = sr_ratio + if sr_ratio > 1: + self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) + self.norm = nn.LayerNorm(dim) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W): + B, N, C = x.shape + q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + + if self.sr_ratio > 1: + x_ = x.permute(0, 2, 1).reshape(B, C, H, W) + x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) + x_ = self.norm(x_) + kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + else: + kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + k, v = kv[0], kv[1] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W): + x = x + self.drop_path(self.attn(self.norm1(x), H, W)) + x = x + self.drop_path(self.mlp(self.norm2(x), H, W)) + + return x + + +class OverlapPatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + + def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + + self.img_size = img_size + self.patch_size = patch_size + self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] + self.num_patches = self.H * self.W + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, + padding=(patch_size[0] // 2, patch_size[1] // 2)) + self.norm = nn.LayerNorm(embed_dim) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x): + x = self.proj(x) + _, _, H, W = x.shape + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + + return x, H, W + + +class PyramidVisionTransformerImpr(nn.Module): + def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512], + num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, + depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]): + super().__init__() + self.num_classes = num_classes + self.depths = depths + + # patch_embed + self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans, + embed_dim=embed_dims[0]) + self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0], + embed_dim=embed_dims[1]) + self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1], + embed_dim=embed_dims[2]) + self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2], + embed_dim=embed_dims[3]) + + # transformer encoder + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + cur = 0 + self.block1 = nn.ModuleList([Block( + dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], 
norm_layer=norm_layer, + sr_ratio=sr_ratios[0]) + for i in range(depths[0])]) + self.norm1 = norm_layer(embed_dims[0]) + + cur += depths[0] + self.block2 = nn.ModuleList([Block( + dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[1]) + for i in range(depths[1])]) + self.norm2 = norm_layer(embed_dims[1]) + + cur += depths[1] + self.block3 = nn.ModuleList([Block( + dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[2]) + for i in range(depths[2])]) + self.norm3 = norm_layer(embed_dims[2]) + + cur += depths[2] + self.block4 = nn.ModuleList([Block( + dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, + sr_ratio=sr_ratios[3]) + for i in range(depths[3])]) + self.norm4 = norm_layer(embed_dims[3]) + + # classification head + # self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity() + + self.apply(self._init_weights) + self.initialize() + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def init_weights(self, pretrained=None): + if isinstance(pretrained, str): + logger = 1 + #load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger) + + def reset_drop_path(self, drop_path_rate): + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))] + cur = 0 + for i in range(self.depths[0]): + self.block1[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[0] + for i in range(self.depths[1]): + self.block2[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[1] + for i in range(self.depths[2]): + self.block3[i].drop_path.drop_prob = dpr[cur + i] + + cur += self.depths[2] + for i in range(self.depths[3]): + self.block4[i].drop_path.drop_prob = dpr[cur + i] + + def freeze_patch_emb(self): + self.patch_embed1.requires_grad = False + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'} # has pos_embed may be better + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + # def _get_pos_embed(self, pos_embed, patch_embed, H, W): + # if H * W == self.patch_embed1.num_patches: + # return pos_embed + # else: + # return F.interpolate( + # pos_embed.reshape(1, patch_embed.H, patch_embed.W, -1).permute(0, 3, 1, 2), + # size=(H, W), mode="bilinear").reshape(1, -1, H * W).permute(0, 2, 1) + + def forward_features(self, x): + B = x.shape[0] + outs = [] + + # stage 1 + x, H, W = self.patch_embed1(x) + for i, blk in enumerate(self.block1): + x = blk(x, H, W) 
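+        # end of stage 1: apply norm1, then reshape the token sequence back to a (B, C, H, W) feature map and store it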
+ x = self.norm1(x) + x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() + outs.append(x) + + # stage 2 + x, H, W = self.patch_embed2(x) + for i, blk in enumerate(self.block2): + x = blk(x, H, W) + x = self.norm2(x) + x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() + outs.append(x) + + # stage 3 + x, H, W = self.patch_embed3(x) + for i, blk in enumerate(self.block3): + x = blk(x, H, W) + x = self.norm3(x) + x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() + outs.append(x) + + # stage 4 + x, H, W = self.patch_embed4(x) + for i, blk in enumerate(self.block4): + x = blk(x, H, W) + x = self.norm4(x) + x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() + outs.append(x) + + #return outs + return outs[::-1] + + # return x.mean(dim=1) + + def forward(self, x): + x = self.forward_features(x) + # x = self.head(x) + + return x + + + def initialize(self): + pass + + +class DWConv(nn.Module): + def __init__(self, dim=768): + super(DWConv, self).__init__() + self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) + + def forward(self, x, H, W): + B, N, C = x.shape + x = x.transpose(1, 2).view(B, C, H, W) + x = self.dwconv(x) + x = x.flatten(2).transpose(1, 2) + + return x + + +def _conv_filter(state_dict, patch_size=16): + """ convert patch embedding weight from manual patchify + linear proj to conv""" + out_dict = {} + for k, v in state_dict.items(): + if 'patch_embed.proj.weight' in k: + v = v.reshape((v.shape[0], 3, patch_size, patch_size)) + out_dict[k] = v + + return out_dict + + +@register_model +class pvt_v2_b0(PyramidVisionTransformerImpr): + def __init__(self, **kwargs): + super(pvt_v2_b0, self).__init__( + patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], + qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + + +@register_model +class pvt_v2_b1(PyramidVisionTransformerImpr): + def __init__(self, **kwargs): + super(pvt_v2_b1, self).__init__( + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], + qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + +@register_model +class pvt_v2_b2(PyramidVisionTransformerImpr): + def __init__(self, **kwargs): + super(pvt_v2_b2, self).__init__( + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], + qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + +@register_model +class pvt_v2_b3(PyramidVisionTransformerImpr): + def __init__(self, **kwargs): + super(pvt_v2_b3, self).__init__( + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], + qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + +@register_model +class pvt_v2_b4(PyramidVisionTransformerImpr): + def __init__(self, **kwargs): + super(pvt_v2_b4, self).__init__( + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], + qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) + + +@register_model +class pvt_v2_b5(PyramidVisionTransformerImpr): + def __init__(self, **kwargs): + super(pvt_v2_b5, 
self).__init__( + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, drop_path_rate=0.1) diff --git a/model/encoder/swin_encoder.py b/model/encoder/swin_encoder.py new file mode 100644 index 0000000..0f9a4bd --- /dev/null +++ b/model/encoder/swin_encoder.py @@ -0,0 +1,608 @@ +#!/usr/bin/python3 +#coding=utf-8 + +# -------------------------------------------------------- +# Swin Transformer +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Ze Liu +# -------------------------------------------------------- + +import torch +import torch.nn as nn +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. 
+ window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + # calculate attention mask for SW-MSA + H, W = self.input_resolution + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def forward(self, x): + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = 
window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops + + +class SwinTransformer(nn.Module): + r""" Swin Transformer + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + + Args: + img_size (int | tuple(int)): Input image size. Default 224 + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input image channels. 
Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, + embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, **kwargs): + super().__init__() + + self.num_classes = num_classes + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), + input_resolution=(patches_resolution[0] // (2 ** i_layer), + patches_resolution[1] // (2 ** i_layer)), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint) + self.layers.append(layer) + + # self.norm = norm_layer(self.num_features) + # self.avgpool = nn.AdaptiveAvgPool1d(1) + # self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + self.apply(self._init_weights) + self.initialize() + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, 
nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def forward_features(self, x): + num_passed = 0 + features = [] + + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + features.append(self.resize_feat(x, num_passed)) + num_passed += 1 + x = layer(x) + + features.append(self.resize_feat(x, num_passed-1)) + + # x = self.norm(x) # B L C + # x = self.avgpool(x.transpose(1, 2)) # B C 1 + # x = torch.flatten(x, 1) + return features + + def resize_feat(self, x, num_passed): + sizes = [96, 48, 24, 12, 12] + size = sizes[num_passed] + resize_x = x.view(-1, size, size, self.num_features[num_passed]).permute(0, 3, 1, 2).contiguous() + return resize_x + + def forward(self, x): + x = self.forward_features(x) + # x = self.head(x) + + return x[::-1] + + def flops(self): + flops = 0 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers): + flops += layer.flops() + flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) + flops += self.num_features * self.num_classes + return flops + + def initialize(self): + pass + + diff --git a/test_eval.sh b/test_eval.sh new file mode 100644 index 0000000..365d777 --- /dev/null +++ b/test_eval.sh @@ -0,0 +1,8 @@ +python main.py \ + --ckpt 'checkpoint/CamoFormer56'\ + --mode 'test' + +python evaltools/eval.py \ + --GT_root 'dataset/TestDataset' \ + --pred_root 'output/Prediction/CamoFormer-test'\ + --BR 'on' \ No newline at end of file diff --git a/train.sh b/train.sh new file mode 100644 index 0000000..33a3b08 --- /dev/null +++ b/train.sh @@ -0,0 +1,12 @@ +python main.py \ + --model 'CamoFormer' \ + --dataset 'dataset/TrainDataset' \ + --test_dataset 'dataset/TestDataset/' \ + --pretrain_path 'checkpoint/pvt_v2_b4.pth' \ + --lr 1e-2 \ + --decay 2e-4 \ + --momen 0.9 \ + --batchsize 8 \ + --savepath 'output/checkpoint/CamoFormer/CamoFormer/' \ + --valid True +
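
A minimal sketch, assuming the module layout released above, of wiring pvt_v2_b4 (model/encoder/pvtv2_encoder.py) to Decoder (model/decoder/decoder_p.py) for a quick forward-pass check. The decoder width (channels=64) and the 384x384 input are illustrative assumptions; the full CamoFormer assembly used by main.py may differ and additionally loads the pvt_v2_b4.pth weights referenced in train.sh.

    import torch
    from model.encoder.pvtv2_encoder import pvt_v2_b4
    from model.decoder.decoder_p import Decoder

    encoder = pvt_v2_b4()             # stage outputs have 64/128/320/512 channels and are returned coarsest-first
    decoder = Decoder(channels=64)    # side convs expect 512/320/128/64-channel inputs; 64 is an assumed width

    x = torch.randn(1, 3, 384, 384)   # assumed input resolution
    E4, E3, E2, E1 = encoder(x)       # E4 is the 1/32-scale feature, E1 the 1/4-scale feature
    P5, P4, P3, P2, P1 = decoder(E4, E3, E2, E1, shape=x.shape[2:])
    print(P1.shape)                   # torch.Size([1, 1, 384, 384]): one single-channel prediction per decoder stage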