diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py
index e1120504d7..281f41753a 100644
--- a/llmfoundry/models/layers/attention.py
+++ b/llmfoundry/models/layers/attention.py
@@ -522,8 +522,7 @@ def __init__(
         fc_kwargs: dict[str, Any] = {
             'bias': bias,
         }
-        if fc_type != 'te':
-            fc_kwargs['device'] = device
+        fc_kwargs['device'] = device
         self.Wqkv = FC_CLASS_REGISTRY[fc_type](
             self.d_model,
             self.d_model + 2 * self.kv_n_heads * self.head_dim,
diff --git a/llmfoundry/models/layers/ffn.py b/llmfoundry/models/layers/ffn.py
index fa3e109bf8..5e99e0a960 100644
--- a/llmfoundry/models/layers/ffn.py
+++ b/llmfoundry/models/layers/ffn.py
@@ -97,8 +97,8 @@ def __init__(
         self.fc_kwargs: dict[str, Any] = {
             'bias': bias,
         }
-        if fc_type != 'te':
-            self.fc_kwargs['device'] = device
+
+        self.fc_kwargs['device'] = device
         self.up_proj = FC_CLASS_REGISTRY[fc_type](
             d_model,
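
For context, a minimal sketch (not part of the patch) of the kwarg flow both hunks change: `device` is now added to the FC constructor kwargs unconditionally, rather than being skipped when `fc_type == 'te'` (presumably because the TransformerEngine linear layer now accepts a `device` argument). The `build_fc` helper below and its registry contents are hypothetical stand-ins for illustration; only the `fc_kwargs` pattern mirrors the diff.

    # Hypothetical sketch of the pattern after this patch; not llm-foundry code.
    from typing import Any

    import torch.nn as nn

    # Stand-in for llmfoundry's FC_CLASS_REGISTRY; the real registry may also
    # map 'te' to transformer_engine's Linear class.
    FC_CLASS_REGISTRY: dict[str, type[nn.Module]] = {
        'torch': nn.Linear,
    }

    def build_fc(
        fc_type: str,
        d_model: int,
        d_out: int,
        bias: bool,
        device: str,
    ) -> nn.Module:
        fc_kwargs: dict[str, Any] = {'bias': bias}
        # After the patch: 'device' is passed for every fc_type, with no
        # special case excluding 'te'.
        fc_kwargs['device'] = device
        return FC_CLASS_REGISTRY[fc_type](d_model, d_out, **fc_kwargs)

    layer = build_fc('torch', 512, 512, bias=True, device='cpu')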