A Hand-Written Implementation of LoRA
2025-06-04
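Both snippets below implement the same idea: wrap a frozen nn.Linear and add a trainable low-rank update, so the forward pass computes y = W0·x + (alpha / r)·B·A·x, where A has shape rank × in_features and B has shape out_features × rank. The first version builds the low-rank branch out of nn.Linear modules; the second holds the two factors directly as nn.Parameter tensors.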

Version 1: nn.Linear
```python
import copy

import torch.nn as nn


class LoraLinear(nn.Module):
    def __init__(self, baselinear, rank, alpha, dropout):
        super().__init__()
        self.rank = rank
        self.alpha = alpha
        self.dropout = nn.Dropout(dropout)
        self.base_linear = copy.deepcopy(baselinear)

        # Use nn.Linear modules instead of raw parameters.
        # A matrix: rank x in_features
        self.lora_A = nn.Linear(self.base_linear.in_features, self.rank, bias=False)
        # B matrix: out_features x rank
        self.lora_B = nn.Linear(self.rank, self.base_linear.out_features, bias=False)

        # Initialize A with a small Gaussian and B with zeros,
        # so the LoRA branch starts as an exact zero update.
        nn.init.normal_(self.lora_A.weight, mean=0, std=0.02)
        nn.init.zeros_(self.lora_B.weight)

        # Freeze the base layer's parameters.
        for param in self.base_linear.parameters():
            param.requires_grad = False

    def forward(self, x):
        scaling = self.alpha / self.rank
        # Low-rank branch: B(dropout(A(x)))
        m = self.lora_B(self.dropout(self.lora_A(x)))
        return self.base_linear(x) + scaling * m


def get_lora_model(module, rank, alpha, dropout):
    for name, child in module.named_children():
        if any(s in name for s in ["embed", "norm", "lm_head"]):
            # Keep embeddings, norms and the LM head as-is, but freeze them.
            for param in child.parameters():
                param.requires_grad = False
        elif isinstance(child, nn.Linear):
            # Replace every remaining nn.Linear with a LoRA-wrapped copy.
            lora_linear = LoraLinear(child, rank, alpha, dropout)
            setattr(module, name, lora_linear)
        else:
            # Recurse into container modules.
            get_lora_model(child, rank, alpha, dropout)
    return module  # return the modified module
```
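As a quick usage sketch (not part of the original post), the snippet below applies get_lora_model to a small toy model and checks that only the LoRA factors remain trainable; the toy architecture and the hyperparameters rank=8, alpha=16, dropout=0.1 are illustrative assumptions.

```python
# Minimal usage sketch; the toy model and hyperparameters are illustrative
# assumptions, not part of the original implementation.
import torch
import torch.nn as nn

toy = nn.Sequential(
    nn.Linear(64, 128),
    nn.ReLU(),
    nn.Linear(128, 64),
)

lora_model = get_lora_model(toy, rank=8, alpha=16, dropout=0.1)

trainable = sum(p.numel() for p in lora_model.parameters() if p.requires_grad)
total = sum(p.numel() for p in lora_model.parameters())
print(f"trainable: {trainable} / total: {total}")

# The forward pass still accepts the same input shape as the base model.
x = torch.randn(4, 64)
print(lora_model(x).shape)  # torch.Size([4, 64])
```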
Version 2: nn.Parameter
```python
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F


class LoraLinear(nn.Module):
    def __init__(self, baselinear, rank, alpha, dropout):
        super().__init__()
        self.rank = rank
        self.alpha = alpha
        self.dropout = nn.Dropout(dropout)
        self.base_linear = copy.deepcopy(baselinear)

        # Raw low-rank factors: A is rank x in_features, B is out_features x rank.
        self.lora_A = nn.Parameter(torch.empty(self.rank, self.base_linear.in_features, dtype=self.base_linear.weight.dtype))
        self.lora_B = nn.Parameter(torch.empty(self.base_linear.out_features, self.rank, dtype=self.base_linear.weight.dtype))

        # Same initialization as the nn.Linear version: small Gaussian for A, zeros for B.
        nn.init.normal_(self.lora_A, mean=0, std=0.02)
        nn.init.zeros_(self.lora_B)

        # Freeze the base layer's parameters.
        for param in self.base_linear.parameters():
            param.requires_grad = False

    def forward(self, x):
        scaling = self.alpha / self.rank
        # Low-rank branch via F.linear: apply A to dropout(x), then B.
        m = F.linear(self.dropout(x), self.lora_A)
        m = F.linear(m, self.lora_B)
        return self.base_linear(x) + scaling * m


def get_lora_model(module, rank, alpha, dropout):
    for name, child in module.named_children():
        if any(s in name for s in ["embed", "norm", "lm_head"]):
            for param in child.parameters():
                param.requires_grad = False
        elif isinstance(child, nn.Linear):
            lora_linear = LoraLinear(child, rank, alpha, dropout)
            setattr(module, name, lora_linear)
        else:
            get_lora_model(child, rank, alpha, dropout)
    return module  # return the modified module
```
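After training, the low-rank update can be folded back into the frozen weight so inference runs as a plain nn.Linear. The original post stops at the wrapped layer; the helper below is a minimal sketch of that merge step, written against the nn.Parameter variant above (lora_A: rank × in_features, lora_B: out_features × rank).

```python
import torch


@torch.no_grad()
def merge_lora(lora_linear):
    """Fold the LoRA update into the frozen base weight: W <- W + (alpha/r) * B @ A.

    Sketch only, assuming the nn.Parameter variant of LoraLinear defined above;
    this helper is not part of the original post.
    """
    scaling = lora_linear.alpha / lora_linear.rank
    delta = scaling * (lora_linear.lora_B @ lora_linear.lora_A)  # out_features x in_features
    lora_linear.base_linear.weight += delta
    return lora_linear.base_linear  # a plain nn.Linear carrying the merged weight
```

Since W0·x + (alpha/r)·B·A·x = (W0 + (alpha/r)·B·A)·x, the merged layer computes the same function as the LoRA layer in eval mode (dropout disabled).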