fix rmsnorm init weight bug.

#59
Files changed (1) hide show
  1. modeling_chatglm.py +1 -1
modeling_chatglm.py CHANGED
@@ -181,7 +181,7 @@ def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Ten
181
  class RMSNorm(torch.nn.Module):
182
  def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
183
  super().__init__()
184
- self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
185
  self.eps = eps
186
 
187
  def forward(self, hidden_states: torch.Tensor):
 
181
  class RMSNorm(torch.nn.Module):
182
  def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
183
  super().__init__()
184
+ self.weight = torch.nn.Parameter(torch.ones(normalized_shape, device=device, dtype=dtype))
185
  self.eps = eps
186
 
187
  def forward(self, hidden_states: torch.Tensor):