找回密码
立即注册
搜索
热搜: Java Python Linux Go
发回帖 发新帖

1552

积分

0

好友

223

主题
发表于 5 天前 | 查看: 17| 回复: 0

Scikit-Learn 1.8.0 的更新带来了一个实验性特性:正式支持 Python Array API 标准。这意味着,CuPy 数组或 PyTorch 张量现在可以直接传递给部分 Scikit-Learn 组件,并且整个计算流程能够保留在 GPU 等非 CPU 设备上执行,无需像过去那样强制转换回 NumPy 数组。

图片

1.8.0 版本的核心更新

此次更新围绕 Python Array API 标准展开,这是一个由 NumPy、CuPy、PyTorch、JAX 等主流数组库共同维护的接口规范。在 1.8.0 版本中,它主要实现了以下功能:

  • 直接传参:受支持的评估器可以直接接收 CuPy 数组或 PyTorch 张量作为输入。
  • 计算分派:算法运算会被自动分派到输入数据所在的设备(如 GPU)上执行。
  • 状态保留:模型训练后产生的属性(如coef_)会与输入数据保持在同一物理设备上。

尽管该功能目前仍标记为“实验性”且需手动开启,但它标志着 Scikit-Learn 开始突破长期以来对 NumPy 数组的强依赖,拥抱更广泛的 Python 科学计算生态。

对交叉验证流程的性能革新

如果你不常使用cross_val_scoreGridSearchCVCalibratedClassifierCV,可能对此更新感知不强。但对于大多数进行严肃建模的开发者而言,交叉验证环节一直是 GPU 利用率的一个瓶颈。

在旧版本中,即使底层模型(如 XGBoost)支持 GPU 训练,Scikit-Learn 的元评估器(meta-estimators)也会在计算性能指标前,将中间数据搬回 CPU 并转换为 NumPy 数组。这种频繁的内存搬运和数据格式转换严重拖慢了整体流程。而 Array API 的支持,使得交叉验证的循环能够基本闭环在 GPU 内部运行,大幅减少开销。

如何启用与当前限制

要启用此特性,需要完成以下配置。如果缺少任何一步,程序将静默回退到传统的 NumPy 模式。

1. 设置环境变量(必须在导入 SciPy 或 Scikit-Learn 之前):

import os
os.environ["SCIPY_ARRAY_API"] = "1"

2. 配置 Scikit-Learn 内部开关

from sklearn import set_config
set_config(array_api_dispatch=True)

当前主要限制:暂不支持 cuDF DataFrames。你仍可以使用 cuDF 进行数据加载和预处理,但在输入模型前,必须确保数据是 array-like 格式(例如.values)。这意味着类别特征需要手动编码,无法再依赖 pandas/cuDF 的 dtype 自动识别。

实战:基于 GPU 的 XGBoost 交叉验证

以下是一个使用 CuPy 数组进行 5 折分层交叉验证的示例。为了确保整个链路停留在 GPU 上,我们需要对XGBClassifier进行简单封装,并配合 cuML 的指标计算。

import os
os.environ['SCIPY_ARRAY_API'] = '1'

import cupy as cp
import cudf
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.metrics import make_scorer
from cuml.metrics import roc_auc_score
from xgboost import XGBClassifier
from sklearn import set_config
set_config(array_api_dispatch=True)

# 加载数据并进行简单的预处理
X = cudf.read_csv('/kaggle/input/playground-series-s5e12/train.csv').set_index('id')
y = X.pop('diagnosed_diabetes').astype(int)

# 类别特征编码处理
cat_cols = [c for c in X.columns if X[c].dtype == 'object']
X = X.astype({c: 'category' for c in cat_cols})
for c in cat_cols:
    X[c] = X[c].cat.codes

ft = ['c' if c in cat_cols else 'q' for c in X.columns]
kfold = StratifiedKFold(5, shuffle=True, random_state=0)

# 封装 XGB 以适配 CuPy 预测
class cuXGBClassifier(XGBClassifier):
    @property
    def classes_(self):
        return cp.asarray(super().classes_)
    def predict_proba(self, X):
        p = self.get_booster().inplace_predict(X)
        if p.ndim == 1:
            p = cp.column_stack([1 - p, p])
        return p
    def predict(self, X):
        return cp.asarray(super().predict(X))

model = cuXGBClassifier(
    enable_categorical=True,
    feature_types=ft,
    device='cuda',
    n_jobs=4,
    random_state=0
)

# 执行交叉验证
scores = cross_val_score(
    model,
    X.values,
    y.values,
    cv=kfold,
    scoring=make_scorer(
        roc_auc_score,
        response_method="predict_proba"
    ),
    n_jobs=1
)
print(f"{scores.mean():.5f} ± {scores.std():.5f}")

虽然代码需要一些适配工作,但它成功地将交叉验证的核心计算留在了 GPU 上,这对于处理大数据集时提升效率至关重要。

现阶段已支持的组件

目前 Array API 的覆盖范围正在逐步扩展。在 1.8.0 版本中,以下组件已具备较好的实验性支持:

  • 预处理StandardScaler, PolynomialFeatures
  • 线性模型与校准RidgeCV, RidgeClassifierCV, CalibratedClassifierCV
  • 聚类与混合模型GaussianMixture

官方示例显示,一个基于 PyTorch 张量的 Ridge 分类管道,在处理线性代数密集型任务时,在 Colab 环境下可比单核 CPU 实现快出近 10 倍。

ridge_pipeline_gpu = make_pipeline(
    feature_preprocessor,
    FunctionTransformer(
        lambda x: torch.tensor(
            x.to_numpy().astype(np.float32),
            device="cuda"
        )
    ),
    CalibratedClassifierCV(
        RidgeClassifierCV(alphas=alphas),
        method="temperature"
    ),
)

with sklearn.config_context(array_api_dispatch=True):
    cv_results = cross_validate(
        ridge_pipeline_gpu, features, target
    )

总结

Scikit-Learn 是否已经准备好完全接管 GPU 计算?答案是否定的,它仍在起步阶段。但 1.8.0 版本的意义在于,它朝着这个方向迈出了坚实的第一步。尽管当前的启用方式对普通用户还不够友好,略显“硬核”,但它为追求极致性能的开发者打开了一扇门,预示着未来 人工智能 和机器学习工作流中,设备异构计算将变得更加无缝和高效。




上一篇:Apache Flink 2.2 Delta Join原理详解:解决双流Join状态膨胀的实战指南
下一篇:CTF逆向工程:Python四种文件格式(pyc/exe/字节码/加花)的分析与反编译实战
您需要登录后才可以回帖 登录 | 立即注册

手机版|小黑屋|网站地图|云栈社区 ( 苏ICP备2022046150号-2 )

GMT+8, 2025-12-24 18:59 , Processed in 0.171294 second(s), 40 queries , Gzip On.

Powered by Discuz! X3.5

© 2025-2025 云栈社区.

快速回复 返回顶部 返回列表