import torch
import torch.nn as nn
import torch.nn.functional as F
class PatchEmbedding(nn.Module):
    """Splits an image into patches, projects them to embed_dim, and adds CLS + positional embeddings."""

    def __init__(self, in_channels, patch_size, embed_dim, image_size):
        super().__init__()
        self.patch_num = (image_size // patch_size) ** 2
        # A conv with kernel_size == stride == patch_size cuts the image into
        # non-overlapping patches and linearly projects each one to embed_dim.
        self.emb = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        # Learnable positional embeddings for all patches plus the CLS token.
        self.pos_emb = nn.Parameter(torch.randn(1, self.patch_num + 1, embed_dim))
        # Learnable classification token prepended to the patch sequence.
        self.cls_emb = nn.Parameter(torch.randn(1, 1, embed_dim))
    def forward(self, x):
        x = self.emb(x)                                   # (B, D, H/P, W/P)
        b, d, h, w = x.shape
        x = x.permute(0, 2, 3, 1).reshape(b, h * w, d)    # (B, N, D) patch sequence
        cls_emb = self.cls_emb.expand(b, -1, -1)          # one CLS token per sample
        x = torch.cat((cls_emb, x), dim=1)                # (B, N + 1, D)
        out = x + self.pos_emb[:, :1 + h * w]             # add positional embeddings
        return out
class MultiHeadSelfAttention(nn.Module):
    """Standard multi-head self-attention with separate Q/K/V projections."""

    def __init__(self, embed_dim, num_heads, dropout=0.0):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.proj_q = nn.Linear(embed_dim, embed_dim)
        self.proj_k = nn.Linear(embed_dim, embed_dim)
        self.proj_v = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        b, l, _ = x.shape
        # Project and split into heads: (B, L, D) -> (B, num_heads, L, head_dim).
        q = self.proj_q(x).reshape(b, l, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.proj_k(x).reshape(b, l, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.proj_v(x).reshape(b, l, self.num_heads, self.head_dim).transpose(1, 2)
        # Scaled dot-product attention.
        atten_score = torch.matmul(q, k.transpose(-2, -1)) / self.head_dim ** 0.5
        atten_weight = F.softmax(atten_score, dim=-1)
        atten_weight = self.drop(atten_weight)
        # Weighted sum of values, then merge heads back to (B, L, D).
        atten = torch.matmul(atten_weight, v).transpose(1, 2).reshape(b, l, self.embed_dim)
        out = self.out_proj(atten)
        out = self.drop(out)
        return out
class MlpBlock(nn.Module):
    """Two-layer feed-forward block with GELU activation and dropout."""

    def __init__(self, embed_dim, mlp_dim, out_dim, dropout=0.0):
        super().__init__()
        self.fc1 = nn.Linear(embed_dim, mlp_dim)
        self.fc2 = nn.Linear(mlp_dim, out_dim)
        self.act = nn.GELU()
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        out = self.drop(x)
        return out
class TransformerEncoderBlock(nn.Module):
    """Pre-norm Transformer block: LayerNorm -> MSA -> residual, then LayerNorm -> MLP -> residual."""

    def __init__(self, embed_dim, num_heads, mlp_dim, dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = MlpBlock(embed_dim, mlp_dim, embed_dim, dropout)
        self.msa = MultiHeadSelfAttention(embed_dim, num_heads, dropout)

    def forward(self, x):
        # Attention sub-layer with residual connection.
        norm = self.norm1(x)
        msa = self.msa(norm)
        res = msa + x
        # Feed-forward sub-layer with residual connection.
        norm = self.norm2(res)
        mlp = self.mlp(norm)
        out = mlp + res
        return out
class TransformerEncoder(nn.Module):
    """Stack of `depth` encoder blocks followed by a final LayerNorm."""

    def __init__(self, embed_dim, num_heads, mlp_dim, depth, dropout=0.0):
        super().__init__()
        self.layers = nn.ModuleList(
            [TransformerEncoderBlock(embed_dim, num_heads, mlp_dim, dropout) for _ in range(depth)]
        )
        self.norm = nn.LayerNorm(embed_dim)
    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        out = self.norm(x)
        return out
class VisionTransformer(nn.Module):
    """Vision Transformer: patch embedding -> Transformer encoder -> linear head on the CLS token."""

    def __init__(self, image_size, in_channels, num_classes, patch_size,
                 embed_dim, depth, num_heads, mlp_dim, dropout=0.0):
        super().__init__()
        self.emb = PatchEmbedding(in_channels, patch_size, embed_dim, image_size)
        self.trans = TransformerEncoder(embed_dim, num_heads, mlp_dim, depth, dropout)
        self.lin = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        emb = self.emb(x)                      # (B, N + 1, D)
        trans_output = self.trans(emb)         # (B, N + 1, D)
        cls_token_output = trans_output[:, 0]  # classification token only
        out = self.lin(cls_token_output)       # (B, num_classes)
        return out
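
# Minimal usage sketch (not part of the original listing): builds a small model and
# runs a forward pass on random input to check shapes. The hyperparameter values
# below are illustrative assumptions (CIFAR-10-like 32x32 RGB images, patch size 4),
# not values taken from the source.
if __name__ == "__main__":
    model = VisionTransformer(
        image_size=32,
        in_channels=3,
        num_classes=10,
        patch_size=4,
        embed_dim=64,
        depth=4,
        num_heads=4,
        mlp_dim=128,
        dropout=0.1,
    )
    dummy = torch.randn(2, 3, 32, 32)  # batch of 2 RGB images
    logits = model(dummy)
    print(logits.shape)                # expected: torch.Size([2, 10])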