yairschiff commited on
Commit
97808c5
1 Parent(s): c699cf9

Update modeling_rcps.py

Browse files
Files changed (1) hide show
  1. modeling_rcps.py +5 -0
modeling_rcps.py CHANGED
@@ -148,6 +148,11 @@ class RCPSMambaBlock(nn.Module):
148
  self.mixer = RCPSWrapper(mixer_cls(dim))
149
  norm_f = norm_cls(dim)
150
  self.norm = norm_f if fused_add_norm else RCPSAddNormWrapper(norm_f)
 
 
 
 
 
151
 
152
  def forward(
153
  self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None
 
148
  self.mixer = RCPSWrapper(mixer_cls(dim))
149
  norm_f = norm_cls(dim)
150
  self.norm = norm_f if fused_add_norm else RCPSAddNormWrapper(norm_f)
151
+ if self.fused_add_norm:
152
+ assert RMSNorm is not None, "RMSNorm import fails"
153
+ assert isinstance(
154
+ self.norm, (nn.LayerNorm, RMSNorm)
155
+ ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
156
 
157
  def forward(
158
  self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None