zRzRzRzRzRzRzR Shan1990 committed
Commit 91a0561
1 Parent(s): b3121c8

fix rmsnorm init weight bug. (#59)


- fix rmsnorm init weight bug. (9d3d7be563d07295abb119ff28714aa9267580b8)


Co-authored-by: Ben <[email protected]>

Files changed (1)
  1. modeling_chatglm.py +1 -1
modeling_chatglm.py CHANGED
@@ -181,7 +181,7 @@ def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
 class RMSNorm(torch.nn.Module):
     def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
         super().__init__()
-        self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
+        self.weight = torch.nn.Parameter(torch.ones(normalized_shape, device=device, dtype=dtype))
         self.eps = eps

     def forward(self, hidden_states: torch.Tensor):
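For context on the bug: torch.empty allocates uninitialized memory, so an RMSNorm constructed from scratch (rather than having its weight overwritten by a checkpoint) would scale activations by arbitrary garbage values. Initializing the weight to ones makes a freshly created layer start as a pure RMS rescaling (an identity scale). Below is a minimal runnable sketch of the corrected module; the forward pass is the standard RMSNorm formulation and is an assumption here, since the diff only shows the __init__ change.

import torch

class RMSNorm(torch.nn.Module):
    def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
        super().__init__()
        # The fix: start from ones (identity scaling) instead of torch.empty,
        # whose uninitialized contents are only safe if a checkpoint
        # immediately overwrites them.
        self.weight = torch.nn.Parameter(torch.ones(normalized_shape, device=device, dtype=dtype))
        self.eps = eps

    def forward(self, hidden_states: torch.Tensor):
        # Standard RMSNorm (assumed; this body is not shown in the diff):
        # divide by the root-mean-square over the last dimension in float32,
        # apply the learned scale, then cast back to the input dtype.
        input_dtype = hidden_states.dtype
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
        return (self.weight * hidden_states).to(input_dtype)

# Quick check: with weight == 1, the output RMS along the last dim is ~1.
x = torch.randn(2, 4, 8)
print(RMSNorm(8)(x).pow(2).mean(-1).sqrt())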