Bobo's Algorithm Notes

Bob Peng

A Hand-Written Implementation of LoRA

2025-06-04
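
LoRA (Low-Rank Adaptation) freezes a pretrained linear layer and learns a low-rank update on top of it: the output becomes base(x) + (alpha / rank) * B(A(x)), where A projects the input down to rank dimensions, B projects it back up, and B is initialized to zero so training starts from the original model's behavior. The two listings below implement the same idea twice: first with nn.Linear sub-layers for A and B, then with raw nn.Parameter tensors applied via F.linear.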

nn.Linear

import copy

import torch.nn as nn


class LoraLinear(nn.Module):
    def __init__(self, baselinear, rank, alpha, dropout):
        super().__init__()
        self.rank = rank
        self.alpha = alpha
        self.dropout = nn.Dropout(dropout)
        self.base_linear = copy.deepcopy(baselinear)
        
        # Use nn.Linear layers instead of raw parameter tensors
        # lora_A maps in_features -> rank (weight shape: rank x in_features)
        self.lora_A = nn.Linear(self.base_linear.in_features, self.rank, bias=False)
        # lora_B maps rank -> out_features (weight shape: out_features x rank)
        self.lora_B = nn.Linear(self.rank, self.base_linear.out_features, bias=False)
        
        # Initialization: A gets small Gaussian noise, B starts at zero,
        # so the LoRA path contributes nothing before training
        nn.init.normal_(self.lora_A.weight, mean=0, std=0.02)
        nn.init.zeros_(self.lora_B.weight)
        
        # Freeze the wrapped base layer
        for param in self.base_linear.parameters():
            param.requires_grad = False
            
    def forward(self, x):
        scaling = self.alpha / self.rank
        
        # Low-rank path: project down with A, apply dropout, project back up with B
        m = self.lora_B(self.dropout(self.lora_A(x)))
        
        return self.base_linear(x) + scaling * m
  
  
def get_lora_model(module, rank, alpha, dropout):
    for name, child in module.named_children():
        # Leave embeddings, normalization layers, and the LM head alone, but freeze them
        if any(s in name for s in ["embed", "norm", "lm_head"]):
            for param in child.parameters():
                param.requires_grad = False
        # Swap every other nn.Linear for a LoRA-wrapped version
        elif isinstance(child, nn.Linear):
            lora_linear = LoraLinear(child, rank, alpha, dropout)
            setattr(module, name, lora_linear)
        # Recurse into composite modules
        else:
            get_lora_model(child, rank, alpha, dropout)
    
    return module  # return the module, modified in place
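
As a quick check (a minimal sketch; the tiny Sequential model below is a hypothetical stand-in, not from the original post), you can wrap a model with get_lora_model and confirm that only the LoRA matrices remain trainable:

import torch.nn as nn

# Hypothetical toy model, purely for illustration
model = nn.Sequential(
    nn.Linear(16, 32),
    nn.ReLU(),
    nn.Linear(32, 16),
)
model = get_lora_model(model, rank=4, alpha=8, dropout=0.05)

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable params: {trainable} / {total}")  # only lora_A / lora_B weights are trainable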

nn.Parameter
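
This is the same wrapper, but with lora_A and lora_B registered directly as nn.Parameter tensors and applied via F.linear in the forward pass. Note one small difference from the version above: here dropout is applied to the input before the down-projection, whereas the nn.Linear version applies it between A and B.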

import copy

import torch
import torch.nn as nn
import torch.nn.functional as F


class LoraLinear(nn.Module):
    def __init__(self, baselinear, rank, alpha, dropout):
        super().__init__()
        self.rank = rank
        self.alpha = alpha
        self.dropout = nn.Dropout(dropout)
        self.base_linear = copy.deepcopy(baselinear)
        # A: rank x in_features, B: out_features x rank, matching the base layer's dtype
        self.lora_A = nn.Parameter(torch.empty(self.rank, self.base_linear.in_features, dtype=self.base_linear.weight.dtype))
        self.lora_B = nn.Parameter(torch.empty(self.base_linear.out_features, self.rank, dtype=self.base_linear.weight.dtype))
        # A gets small Gaussian noise, B starts at zero
        nn.init.normal_(self.lora_A, mean=0, std=0.02)
        nn.init.zeros_(self.lora_B)
        # Freeze the wrapped base layer
        for param in self.base_linear.parameters():
            param.requires_grad = False
            
    def forward(self, x):
        scaling = self.alpha / self.rank
        # Low-rank path: dropout on the input, then A, then B
        m = F.linear(self.dropout(x), self.lora_A)
        m = F.linear(m, self.lora_B)
        return self.base_linear(x) + scaling * m
  
  
def get_lora_model(module, rank, alpha, dropout):
    for name, child in module.named_children():
        if any(s in name for s in ["embed", "norm", "lm_head"]):
            for param in child.parameters():
                param.requires_grad = False
        elif isinstance(child, nn.Linear):
            lora_linear = LoraLinear(child, rank, alpha, dropout)
            setattr(module, name, lora_linear)
        else:
            get_lora_model(child, rank, alpha, dropout)
    
    return module  # return the module, modified in place
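
Because lora_B is zero-initialized, a freshly wrapped layer should reproduce the base layer exactly. A minimal sanity check (the layer sizes here are arbitrary, chosen only for illustration):

import torch
import torch.nn as nn

base = nn.Linear(8, 4)
lora = LoraLinear(base, rank=2, alpha=4, dropout=0.0)
lora.eval()  # make sure dropout is inactive

x = torch.randn(3, 8)
with torch.no_grad():
    # B starts at zero, so the LoRA path contributes nothing yet
    assert torch.allclose(lora(x), base(x))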