This post documents the process of adapting a PyTorch model to MLX.

What is MLX?

MLX is an array framework for machine learning on Apple silicon, brought to you by Apple machine learning research.
(https://github.com/ml-explore/mlx/blob/main/README.md)

MLX is a machine learning framework designed for Apple's M-series chips (Apple silicon).

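The mlx array is designed to be closer to numpy's array [1] than to PyTorch's tensor: it holds only structural information (such as shape and data type) and none of the attributes tied to deep-learning training (such as gradients). Unlike numpy and torch, an mlx array lives in unified memory and can be shared between the CPU and GPU, which is also why mlx was developed as a separate framework rather than as an extension of PyTorch's mps backend [2].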

The difference between array and tensor

For the difference between np.array and torch.tensor, see "What is a Tensor in Machine Learning?".

Model conversion in practice

BigVGAN PyTorch -> mlx-BigVGAN

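Basic mapping

| | torch | mlx |
| --- | --- | --- |
| DataTypes | Data types | Data Types |
| NN | torch.nn.* | mlx.nn.* |
| Parameters/Weight/Buffer | torch.Tensor | mlx.core.array |
| ModuleList | torch.nn.ModuleList | list |
| ModuleDict | torch.nn.ModuleDict | dict |
| Transform | torch.fft | mlx.core.fft |

Pad mode

torch's pad supports modes such as constant, reflect, and replicate, while mlx's pad supports the constant and edge modes.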

The reflect mode is not supported in mlx, but it can be implemented by hand:

```python
import mlx.core as mx


def pad_reflect(x: mx.array, padding: tuple | int) -> mx.array:
    """
    pad the input array with `reflect` mode in last axis
    """
    if isinstance(padding, int):
        padding = (padding, padding)
    # Mirror the elements next to each edge, excluding the edge element itself.
    prefix = x[..., 1 : padding[0] + 1][..., ::-1]
    suffix = x[..., -(padding[1] + 1) : -1][..., ::-1]
    return mx.concatenate([prefix, x, suffix], axis=-1)
```

For the torch/mlx padding comparison table, assume x is a 2D tensor of shape (M, N) and that padding of width 1 is added on the right side of the last dimension.
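As a quick sanity check of the slicing in pad_reflect, the same logic can be mirrored in NumPy and compared against np.pad with mode="reflect". This is only a sketch: pad_reflect_np is a hypothetical NumPy stand-in, and it assumes NumPy slicing behaves like the mlx slicing used in the function above.

```python
import numpy as np


def pad_reflect_np(x: np.ndarray, padding) -> np.ndarray:
    """NumPy mirror of pad_reflect (hypothetical helper): reflect-pad the last axis."""
    if isinstance(padding, int):
        padding = (padding, padding)
    left, right = padding
    prefix = x[..., 1 : left + 1][..., ::-1]
    suffix = x[..., -(right + 1) : -1][..., ::-1]
    return np.concatenate([prefix, x, suffix], axis=-1)


x = np.arange(12, dtype=np.float32).reshape(3, 4)
for pad in [(0, 1), (2, 3), (3, 0)]:
    # np.pad's reflect mode serves as the reference result
    expected = np.pad(x, [(0, 0), pad], mode="reflect")
    assert np.array_equal(pad_reflect_np(x, pad), expected)
```

The agreement only holds while each padding width is smaller than the length of the last axis, which is the same constraint torch places on its reflect mode.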

| pad mode | torch | mlx |
| --- | --- | --- |
| constant | F.pad(x, (0, 1), mode="constant", value=0) | mx.pad(x, [(0, 0), (0, 1)], mode="constant", constant_values=0) |
| replicate | F.pad(x, (0, 1), mode="replicate") | mx.pad(x, [(0, 0), (0, 1)], mode="edge") |
| reflect | F.pad(x, (0, 1), mode="reflect") | pad_reflect(x, (0, 1)) (custom pad_reflect function) |

nn.Module

An mlx nn.Module does not need to define a forward method; it implements __call__ directly.

For example, here are the torch and mlx implementations of a Snake module:

torch:

```python
import torch
from torch import nn, sin, pow
from torch.nn import Parameter


class Snake(nn.Module):
    def __init__(
        self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
    ):
        super(Snake, self).__init__()
        self.in_features = in_features
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:
            self.alpha = Parameter(torch.zeros(in_features) * alpha)
        else:
            self.alpha = Parameter(torch.ones(in_features) * alpha)
        self.alpha.requires_grad = alpha_trainable
        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # Line up with x to [B, C, T]
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
        x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
        return x
```

mlx:

```python
import mlx.core as mx
import mlx.nn as nn


class Snake(nn.Module):
    def __init__(
        self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
    ):
        super(Snake, self).__init__()
        self.in_features = in_features
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:
            self.alpha = mx.zeros(in_features) * alpha
        else:
            self.alpha = mx.ones(in_features) * alpha
        if not alpha_trainable:
            self.freeze(keys="alpha")
        self.no_div_by_zero = 0.000000001

    def __call__(self, x):
        # Line up with x to [B, T, C]: (C,) -> (1, 1, C).
        # (The original snippet used axis=(0, -1), which yields (1, C, 1)
        # and only matches a [B, C, T] layout.)
        alpha = mx.expand_dims(self.alpha, axis=(0, 1))
        if self.alpha_logscale:
            alpha = mx.exp(alpha)
        x = x + (1.0 / (alpha + self.no_div_by_zero)) * mx.power(mx.sin(x * alpha), 2)
        return x
```

Differences

mlx uses the mlx.nn.Module.freeze [3] method to freeze parameters that should not be trained, whereas torch uses torch.Tensor.requires_grad_ [4] to control whether a tensor is trained.

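weight_norm

torch provides a weight_norm method intended to improve training stability and generalization. weight_norm decomposes the weight into two parts, v and g, where v is a vector and g is a scalar. The weight is computed as: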

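$$w = g \frac{v}{\|v\|}$$

However, mlx does not yet provide an equivalent (a community PR, "Implement Weight Normalization", is still under review). Fortunately, for models that only need to run inference, we can apply an operation similar to remove_weight_norm once during weight conversion.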
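The formula can be sketched in NumPy as follows. This is a minimal illustration rather than torch's implementation: fold_weight_norm is a hypothetical helper, the shapes are made up, and it assumes the norm is taken over every axis except axis 0 (the dim=0 convention), so that g holds one magnitude per output channel.

```python
import numpy as np


def fold_weight_norm(v: np.ndarray, g: np.ndarray) -> np.ndarray:
    """Compute w = g * v / ||v||, with the norm taken over all axes except axis 0."""
    reduce_axes = tuple(range(1, v.ndim))
    norm = np.sqrt(np.sum(v * v, axis=reduce_axes, keepdims=True))
    return g * v / norm


rng = np.random.default_rng(0)
v = rng.standard_normal((8, 4, 3))  # hypothetical (out_channels, in_channels, kernel_size)
g = rng.standard_normal((8, 1, 1))  # one magnitude per output channel
w = fold_weight_norm(v, g)

# Each output channel of w keeps the direction of v and has norm |g|.
w_norm = np.sqrt(np.sum(w * w, axis=(1, 2)))
assert np.allclose(w_norm, np.abs(g).reshape(-1))
```

Once w has been computed this way, the separate v and g tensors are no longer needed at inference time, which is why the folding can be done once during conversion.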

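The conversion script folds each weight_v / weight_g pair back into a single weight:

```python
import torch

# Note: the keyword argument is `weights_only`, not `weight_only`.
origin_state_dict = torch.load("origin_model.pth", map_location="cpu", weights_only=True)
out_weights = {}
for k, v in origin_state_dict.items():
    # handle weight norm
    if k.endswith(("weight_v", "weight_g")):
        basename, pname = k.rsplit(".", 1)
        if pname == "weight_v":
            g = origin_state_dict[basename + ".weight_g"]
            # compute weight
            k = basename + ".weight"
            v = torch._weight_norm(v, g, dim=0)
        else:  # pname == "weight_g"
            continue
    ...
    out_weights[k] = v
```

Conv1d

mlx's Conv1d has essentially the same interface as torch's. The main difference is the input layout: torch's Conv1d expects (B, C, L), while mlx expects (B, L, C), where B is the batch size, C is the number of channels, and L is the sequence length (for example, the number of time frames). The weight layout differs as well: torch's Conv1d weight has shape (out_channels, in_channels // groups, kernel_size), while mlx's has shape (out_channels, kernel_size, in_channels // groups). The bias shape is the same in both: (out_channels).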

torch:

```python
import torch
from torch import nn

# in_channels, out_channels, kernel_size and dilation are assumed to be defined
conv = nn.Conv1d(
    in_channels,
    out_channels,
    kernel_size,
    stride=1,
    dilation=dilation,
    padding=(kernel_size * dilation - dilation) // 2,
)
print(conv.weight.shape)  # (out_channels, in_channels, kernel_size)
x = ...  # in shape (B, in_channels, seq_len)
y = conv(x)
```

mlx:

```python
import mlx.core as mx
import mlx.nn as nn

# in_channels, out_channels, kernel_size and dilation are assumed to be defined
conv = nn.Conv1d(
    in_channels,
    out_channels,
    kernel_size,
    stride=1,
    dilation=dilation,
    padding=(kernel_size * dilation - dilation) // 2,
)
print(conv.weight.shape)  # (out_channels, kernel_size, in_channels)
x = ...  # in shape (B, seq_len, in_channels)
y = conv(x)
```

Weight conversion:

```python
import mlx.core as mx

torch_conv1d_weight = ...  # Tensor in shape: (out_channels, in_channels, kernel_size)
# Convert through NumPy; the tensor is assumed to be on the CPU and detached.
mlx_conv_weight = mx.array(torch_conv1d_weight.permute(0, 2, 1).numpy())  # (out_channels, kernel_size, in_channels)
# or
mlx_conv_weight = mx.array(torch_conv1d_weight.moveaxis(1, 2).numpy())  # (out_channels, kernel_size, in_channels)
```

| Conv1d | torch | mlx |
| --- | --- | --- |
| Input shape | (B, C_in, L_seq) | (B, L_seq, C_in) |
| Weight shape | (C_out, C_in // groups, kernel_size) | (C_out, kernel_size, C_in // groups) |

ConvTranspose1D

The input and weight differences for ConvTranspose1D are similar to those for Conv1d.

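| ConvTranspose1D | torch | mlx |
| --- | --- | --- |
| Input shape | (B, C_in, L_seq) | (B, L_seq, C_in) |
| Weight shape | (C_in, C_out, kernel_size) | (C_out, kernel_size, C_in) |

Note in particular that mlx's ConvTranspose1D does not support the groups parameter, i.e. it behaves as groups=1.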
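To make the layout difference concrete, here is a small NumPy sketch, not the actual torch or mlx kernels. It implements a naive Conv1d in both layouts (stride 1, no padding, groups=1) and checks that transposing the input and weight as described in the tables gives the same result. Both helper functions and all shapes are hypothetical.

```python
import numpy as np


def conv1d_channels_first(x, w):
    """Naive torch-style Conv1d: x is (B, C_in, L), w is (C_out, C_in, K)."""
    B, _, L = x.shape
    C_out, _, K = w.shape
    out = np.zeros((B, C_out, L - K + 1))
    for t in range(L - K + 1):
        # contract over input channels and kernel taps
        out[:, :, t] = np.einsum("bck,ock->bo", x[:, :, t : t + K], w)
    return out


def conv1d_channels_last(x, w):
    """Naive mlx-style Conv1d: x is (B, L, C_in), w is (C_out, K, C_in)."""
    B, L, _ = x.shape
    C_out, K, _ = w.shape
    out = np.zeros((B, L - K + 1, C_out))
    for t in range(L - K + 1):
        out[:, t, :] = np.einsum("bkc,okc->bo", x[:, t : t + K, :], w)
    return out


rng = np.random.default_rng(0)
x_torch = rng.standard_normal((2, 4, 10))  # (B, C_in, L)
w_torch = rng.standard_normal((6, 4, 3))   # (C_out, C_in, K)

x_mlx = x_torch.transpose(0, 2, 1)  # (B, L, C_in)
w_mlx = w_torch.transpose(0, 2, 1)  # (C_out, K, C_in)

y_torch = conv1d_channels_first(x_torch, w_torch)  # (B, C_out, L_out)
y_mlx = conv1d_channels_last(x_mlx, w_mlx)         # (B, L_out, C_out)
assert np.allclose(y_torch.transpose(0, 2, 1), y_mlx)

# ConvTranspose1D weight: (C_in, C_out, K) -> (C_out, K, C_in), shape check only
wt_torch = rng.standard_normal((4, 6, 3))
wt_mlx = wt_torch.transpose(1, 2, 0)
assert wt_mlx.shape == (6, 3, 4)
```

The ConvTranspose1D part only demonstrates the axis permutation implied by the table; it does not verify numerical equivalence of the transposed convolution itself.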

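Performance fell short of expectations

After successfully adapting BigVGAN to mlx, I compared it against the original implementation on my Apple M3 (16 GB) MacBook Pro 14". Well, it turned out to be slower than the original PyTorch implementation (all that work for nothing...).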

```
BigVGAN: 2.3289 seconds per inference
MLX BigVGAN: 4.5342 seconds per inference
Compiled MLX BigVGAN: 4.3205 seconds per inference
```

Further profiling [5] showed that mlx's conv1d and conv_transpose1d are slower than the torch mps backend. I filed an issue, mlx#2180, which has not received a reply as of this writing.

```
conv1d input: 8x256x1000 weight: 256x32x12 groups: 8
conv_transpose1d input: 8x256x1000 weight: 256x1x12 groups: 256
torch(mps) conv1d: 0.943 ms
mlx_conv1d: 3.906 ms
diff: -2.9631301250046818
torch(mps) conv_transpose1d: 2.912 ms
mlx conv_transpose1d: 5.282 ms
diff: -2.3704653340100776
```
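The absolute differences are easier to interpret as ratios. The short calculation below reuses only the numbers reported in this post; the slowdown helper is a hypothetical name, and the "diff" lines are assumed to be the torch time minus the mlx time in milliseconds, which matches the printed values.

```python
# Timings copied from the benchmark output (milliseconds per call).
timings_ms = {
    "conv1d": {"torch_mps": 0.943, "mlx": 3.906},
    "conv_transpose1d": {"torch_mps": 2.912, "mlx": 5.282},
}


def slowdown(torch_ms: float, mlx_ms: float) -> float:
    """How many times longer the mlx op takes than the torch (mps) op."""
    return mlx_ms / torch_ms


for op, t in timings_ms.items():
    diff = t["torch_mps"] - t["mlx"]  # same sign convention as the "diff" lines
    ratio = slowdown(t["torch_mps"], t["mlx"])
    print(f"{op}: diff = {diff:.3f} ms, mlx takes {ratio:.2f}x as long")

# End-to-end numbers from the inference comparison (seconds per inference).
print(f"end to end: mlx takes {slowdown(2.3289, 4.5342):.2f}x as long")
```

In other words, mlx takes roughly 4.1x as long for conv1d and roughly 1.8x as long for conv_transpose1d in this benchmark, in line with the roughly 1.9x end-to-end gap.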

Evidently, MLX still has plenty of room for improvement. I look forward to official updates and contributions from the community.

Refs

- MLX
- MLX Docs
- PyTorch Docs
- MLX Implement Weight Normalization

1. numpy: NumPy is the fundamental package for scientific computing in Python.
2. awni's answer to "Why not implement this in Pytorch?"
3. freeze: Freeze the Module's parameters or some of them.
4. torch.Tensor.requires_grad_: Sets the requires_grad attribute of the tensor.
5. benchmark scripts for conv1d and conv_transpose1d

Author: Yrom

Link: https://yrom.net/blog/2025/05/14/adapt-a-pytorch-model-to-mlx/

License: Creative Commons Attribution-NonCommercial 4.0 International License