yairschiff
committed on
Commit
•
97808c5
1
Parent(s):
c699cf9
Update modeling_rcps.py
Browse files - modeling_rcps.py +5 -0
modeling_rcps.py
CHANGED
@@ -148,6 +148,11 @@ class RCPSMambaBlock(nn.Module):
|
|
148 |
self.mixer = RCPSWrapper(mixer_cls(dim))
|
149 |
norm_f = norm_cls(dim)
|
150 |
self.norm = norm_f if fused_add_norm else RCPSAddNormWrapper(norm_f)
|
|
|
|
|
|
|
|
|
|
|
151 |
|
152 |
def forward(
|
153 |
self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None
|
|
|
148 |
self.mixer = RCPSWrapper(mixer_cls(dim))
|
149 |
norm_f = norm_cls(dim)
|
150 |
self.norm = norm_f if fused_add_norm else RCPSAddNormWrapper(norm_f)
|
151 |
+
if self.fused_add_norm:
|
152 |
+
assert RMSNorm is not None, "RMSNorm import fails"
|
153 |
+
assert isinstance(
|
154 |
+
self.norm, (nn.LayerNorm, RMSNorm)
|
155 |
+
), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
|
156 |
|
157 |
def forward(
|
158 |
self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None
|