diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py
index c003940be..524319f6b 100644
--- a/llms/mlx_lm/generate.py
+++ b/llms/mlx_lm/generate.py
@@ -150,7 +150,7 @@ def main():
         tokenizer,
         prompt,
         args.max_tokens,
-        verbose=True,
+        verbose=mx.distributed.init().rank() == 0,
         formatter=formatter,
         temp=args.temp,
         top_p=args.top_p,
diff --git a/llms/mlx_lm/models/llama.py b/llms/mlx_lm/models/llama.py
index 2a49ee377..bb28aa5fb 100644
--- a/llms/mlx_lm/models/llama.py
+++ b/llms/mlx_lm/models/llama.py
@@ -207,6 +207,36 @@ def sanitize(self, weights):
             k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
         }
 
+    def shard(self, group: Optional[mx.distributed.Group] = None):
+        group = group or mx.distributed.init()
+
+        def all_to_sharded(l):
+            if isinstance(l, nn.QuantizedLinear):
+                return nn.QuantizedAllToShardedLinear.from_quantized_linear(l, group)
+            else:
+                return nn.AllToShardedLinear.from_linear(l, group)
+
+        def sharded_to_all(l):
+            if isinstance(l, nn.QuantizedLinear):
+                return nn.QuantizedShardedToAllLinear.from_quantized_linear(l, group)
+            else:
+                return nn.ShardedToAllLinear.from_linear(l, group)
+
+        N = group.size()
+        for layer in self.model.layers:
+            # Shard the self attention
+            layer.self_attn.q_proj = all_to_sharded(layer.self_attn.q_proj)
+            layer.self_attn.k_proj = all_to_sharded(layer.self_attn.k_proj)
+            layer.self_attn.v_proj = all_to_sharded(layer.self_attn.v_proj)
+            layer.self_attn.o_proj = sharded_to_all(layer.self_attn.o_proj)
+            layer.self_attn.n_heads //= N
+            layer.self_attn.n_kv_heads //= N
+
+            # Shard the MLP
+            layer.mlp.gate_proj = all_to_sharded(layer.mlp.gate_proj)
+            layer.mlp.down_proj = sharded_to_all(layer.mlp.down_proj)
+            layer.mlp.up_proj = all_to_sharded(layer.mlp.up_proj)
+
     @property
     def layers(self):
         return self.model.layers
@@ -217,4 +247,4 @@ def head_dim(self):
 
     @property
     def n_kv_heads(self):
-        return self.args.num_key_value_heads
+        return self.args.num_key_value_heads // mx.distributed.init().size()
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 229ee2381..dfd5c3b7c 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -415,6 +415,11 @@ def class_predicate(p, m):
 
     model.load_weights(list(weights.items()))
 
+    if mx.distributed.init().size() > 1:
+        if not hasattr(model, "shard"):
+            raise RuntimeError("Model doesn't support distributed inference.")
+        model.shard()
+
     if not lazy:
         mx.eval(model.parameters())
 
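
For context, a minimal sketch of how the sharded path could be exercised once this patch is applied. It assumes the processes are launched with MPI (e.g. `mpirun -np 2 python example.py`); the model path and generation arguments are illustrative and not part of the diff.

```python
# Hypothetical usage sketch (not part of this patch). With the utils.py change
# above, load() shards the model automatically whenever the distributed group
# has more than one rank.
import mlx.core as mx
from mlx_lm import load, generate

group = mx.distributed.init()  # one process per shard, e.g. started via mpirun

# Illustrative model path; any Llama-architecture checkpoint works here once
# the model class provides shard().
model, tokenizer = load("mlx-community/Meta-Llama-3-8B-Instruct-4bit")

response = generate(
    model,
    tokenizer,
    prompt="Write a haiku about distributed inference.",
    max_tokens=64,
    verbose=group.rank() == 0,  # mirror generate.py above: only rank 0 prints
)
```

The layout in `shard()` is the usual tensor-parallel split: q/k/v and gate/up projections become all-to-sharded (column-parallel) layers, while o_proj and down_proj become sharded-to-all (row-parallel) layers, so each attention and MLP block ends with a single collective across the group.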