scikit-learnで利用できる自作の回帰モデルの作成#

本稿では、ハイパーパラメータ探索の基本的な手法で利用したパイプライン( Pipeline )やハイパーパラメータ探索( GridSearchCV など)を始めとした、scikit-learnで提供されている各種機能を安全に利用できるような自作の回帰モデルを作成する方法を、公式ドキュメントのDeveloping scikit-learn estimatorsの内容に沿って解説します。なお、回帰ではなく分類をする分類器(Classifier)や、前処理をするTransformerの作成方法についても末尾の付録にて補足 していますので、そちらもご参照ください。

注意

本稿の説明はscikit-learnライブラリの v1.6.1 に準拠したものとなっています。

はじめに#

scikit-learnのパイプライン( Pipeline )やハイパーパラメータ探索( GridSearchCV など)といった各種機能と整合性のとれた回帰モデルを作成するためには、

  • 必要なメソッドが定義されている( fit() など)

  • 各メソッドの挙動がscikit-learnにおいて想定された範疇である

  • 不正な入力に対して、適切なエラーを発出する

などの条件を満たす必要があります。これらの条件は細かいものまで挙げると非常に多岐にわたりますが、自作の回帰モデルに適切な基底クラスを継承させたり、scikit-learnで用意されている関数を用いたりすることで比較的容易に満たすことができます。また、これらの条件が満たされているか否かを確認できるテスト関数 check_estimator() もscikit-learnでは用意されています。

本稿では以下の流れで自作の回帰モデルを作成する手順を解説していきます:

  1. クラスの作成

  2. コンストラクタ( __init__() )の実装

  3. fit()メソッドの実装

  4. predict() メソッドの実装

  5. テスト

最終的には、5.のテストにおいてテスト関数 check_estimator() に作成した回帰モデルを入力し、何もエラーが出なければ合格、すなわちscikit-learnの各種機能と整合性が取れているとみなします。

本稿では、最後に例として切片なしのリッジ回帰および常に平均値を返すベースラインモデルの実装例を紹介します。実装例を先にご覧になりたい方はこちらのリンクからご確認ください。

Step 0: コーディング規約#

まずは回帰モデルを作成する前にscikit-learn内のコーディング規約の内容を紹介します。scikit-learn公式ドキュメントのCoding guidelinesでは以下の内容に従うことが推奨されています:

  • PEP8に従う。

  • クラス以外の命名においてはアンダースコアで単語を区切る( nsamples ではなく n_samples とする)

  • iffor 文の後は改行する

  • scikit-learn内へは相対インポートを推奨(単体テストの場合を除く)

  • import * は絶対に使用しない

  • docstringのスタイルはnumpy docstring satandardに従う

また、アルゴリズムの実装においては下記の点にも注意するとよいでしょう:

  • 関数・変数名がscikit-learn内で用いられているものと重複しないように注意する。

  • 乱数を利用する場合は numpy.random.random() のような形式ではなく、 random_state 引数から numpy.random.RandomState オブジェクトを生成して乱数を生成する。その際は sklearn.utils.check_random_state() が便利です。

Step 1: クラスの作成#

クラスを定義する際、まずは sklearn.base.BaseEstimator を継承しましょう。これにはscikit-learnのモデル選択などの機能を利用するために必要なメソッドや、回帰モデルを含むさまざまな推定器を作成する際に便利なメソッドやが実装されています。例えば以下のようなメソッドが実装されています:

  • get_params() , set_params() : モデルのハイパーパラメータの取得・設定。モデル選択で回帰モデルを複製する際に用いられる。

  • _validate_params() : モデルのハイパーパラメータのバリデーションを行うことが可能。(Step 3-1で実際に使用する)

BaseEstimator に加えて、下記のクラスから対応するクラスを継承することで、作成したモデルが分類・回帰・クラスタリングのどれを行うモデルであるかを明示します:

クラス

種類

_estimator_type

定義されるメソッド

その他タグ

ClassifierMixin

分類

"classifier"

score() メソッドをaccuracyで定義

fit() 時に引数 y が必須になる

RegressorMixin

回帰

"regressor"

score() メソッドをr2で定義

fit() 時に引数 y が必須になる

ClusterMixin

クラスタリング

"clusterer"

fit_predict() の定義

上記のクラスは、継承することでクラス変数 _esimator_type に対応する文字列が格納されるほか、種類に応じて必要なメソッド( score() など)が定義されたり、 fit() 時に引数 y が必要であるか否かのタグ付けが自動で行われたりなどするので便利です。

Step 2: コンストラクタ( __init__() )の実装#

ポイント

  • 引数にはすべてデフォルト値を持たせる

  • 引数を同名のインスタンス変数にそのまま代入する処理のみを行う

コンストラクタではハイパーパラメータを引数から受け取り、それをそのままインスタンス変数に格納する処理のみを行います。言い換えると、それ以外の処理をコンストラクタ内で行ってはいけません。

具体的な注意点としては以下の通りです:

  1. 引数はすべてモデルのハイパーパラメータである(学習データは __init__() では受け取らない)

  2. 引数はすべてデフォルト値を持つキーワード引数である必要がある

  3. コンストラクタ内では、すべてのキーワード引数を同名のインスタンス変数に値を変えずに代入する処理のみを行う(バリデーションも行わない

  4. 変数名の末尾にアンダースコア( _ )を使用してはならない

  5. docstringを書く場合、引数をAttributesではなくParametersセクションに記載する

注釈

上記の注意点3.では重要なポイントが2つあります。1つ目は「インスタンス変数と引数の名前を一致させる」ことで、これはこの後のStep 3-1でパラメータのバリデーションを行う際に必要になるからです。2つ目は「値を変えずに代入する処理のみを行う」ことですが、これはscikit-learnのモデル選択機能との兼ね合いです。scikit-learnのモデル選択機能ではハイパーパラメータの設定それぞれに対応した回帰モデルを作成するのですが、その際回帰モデルを複製して新たなパラメータをそのまま代入する処理がなされます。このハイパーパラメータの設定ごとに回帰モデルを作成する処理とコンストラクタの処理を一致させるため、というのがコンストラクタで代入以外の処理を行ってはならない理由です。

注意点4.の「変数名の末尾にアンダースコア( _ )を使用してはならない」の理由は、変数名末尾のアンダースコアは fit() の中で定義・設定される変数であることを表す印であるためです。

コンストラクタの実装例は下の通りです:

def __init__(self, param1=1, param2=2):
    self.param1 = param1
    self.param2 = param2

逆に、下のような実装は誤りです:

# [NG] param1にデフォルト値が設定されていない
def __init__(self, param1, param2=1.0, param3=1.0):
    # [NG] コンストラクタ内で代入以外の処理を行っている
    if param1 < 0.0:
        raise ValueError(f'param1 must be non-negative.')

    self.param1 = param1
    self.param2_ = param2  # [NG] 引数とインスタンス変数の名前が異なる
    self.param3 = abs(param3)  # [NG] 引数から値を変えて代入している

Step 3: fit() メソッドの実装#

fit() メソッドでは訓練データ X , y を受け取り、モデルの学習をしたのちに自身( self )を返り値として返します。

引数としては

  • 説明変数 X (形状が (サンプル数, 特徴量数) の2次元配列)

  • 目的変数 y (形状が(サンプル数, )の1次元配列)

    • 教師なし学習の場合も、 y=None をデフォルト値とする必要がある

  • その他、データに依存するパラメータ(例: グラム行列、affinity matrix)

を受け取ります。ここで、 X , y 以外の引数としてデータに依存しないパラメータ、例えばハイパーパラメータや反復アルゴリズムにおける停止基準 tolfit() の引数にしてはなりません。そのようなデータに依存しないパラメータはコンストラクタの方で引数として渡すことが推奨されます。

なお、教師なし学習の場合でも y=None を引数とするのは、パイプラインで教師あり・教師なしを統一的に取り扱うためです。( fit_predict() , fit_transform() , score() , partial_fit() でも同様)

fit() メソッドは以下の性質を持つように実装する必要があります:

  • Xy のサンプル数が一致しない場合、 ValueError を発出する(後述するバリデーション方法を用いれば、自動でエラー発出まで行うことが可能)

  • fit() を複数回呼んだ際の結果は、一番最後の fit() のみに依存する

    • ただし、アルゴリズム内で乱数を用いる場合や、前回の学習で得たパラメータの値を再利用する warm_start を行う場合を除く

本節では、 fit() メソッドで実装すべき処理について、下記の順番で説明していきます:

  1. パラメータのバリデーション

  2. データ X , y のバリデーション

  3. 学習アルゴリズム

  4. 終了時の処理

Step 3-1: パラメータのバリデーション#

ポイント

  • 辞書型のクラス変数 _parameter_constraints を定義する

  • BaseEstimator._validate_params() メソッドを呼んでパラメータのバリデーションを行う

ここでは fit() メソッドの中で実際に学習する処理を行う前に、モデルのハイパーパラメータが不正な値でないかの確認(バリデーション)を行います。例えばscikit-learnで実装されているリッジ回帰やLASSOにおける正則化係数 alpha のように非負の値である必要のある引数に負の値が渡された際にエラーを発出する、というのがバリデーションです。

バリデーションは手動で行ってもよいですが、ここでは BaseEstimator_validate_params() メソッドを用いたバリデーションを紹介します。 _validate_params() メソッドは、辞書型のクラス変数 _parameter_constraints の内容に従って自動でバリデーションを行い、不正な値があれば内容に応じたエラーを出すメソッドです。 _validate_params() を呼ぶ際は _validate_params() を直接呼ぶか、あるいは fit() をデコレータ @_fit_context(prefer_skip_nested_validation=True) でラップすることで自動的に冒頭で呼ばれるようにするとよいでしょう。

なお、この方法でパラメータのバリデーションを行う際はクラス変数 _parameter_constraints にパラメータの制約を記載する必要があります。この _parameter_constraints の書き方については付録をご参照ください。

注釈

_fit_context() の引数 prefer_skip_nested_validationTrue にすると、回帰モデル内でさらに別の回帰モデルを利用している場合に内側の回帰モデルのパラメータバリデーションを省略します。これにより、すでに問題ないと判断されたパラメータに対し何度もバリデーションが行われることを防げます。そのため基本的には prefer_skip_nested_validation=True として問題ありません。例外的に prefer_skip_nested_validation=False とする場面は、バリデーションされていない回帰モデルをパラメータとして受け取るメタ推定器を定義する場合などです。

パラメータが特定の型であるかや、特定の範囲の整数・実数であるかといった比較的単純なルールだけであればこの方法で十分ですが、より複雑なバリデーションが必要な場合は _validate_params() の後に追加でバリデーション処理を行うとよいでしょう。

Step 3-2: データ X , y のバリデーション#

ポイント

X , y のバリデーションは、 sklearn.utils.validation 内の validate_data() 関数を使うと楽に行えます。この関数では、 XyNaN が含まれないかの確認などを行い、返り値として numpy.ndarray 型になった X , y を返します。なお、この関数では引数を工夫することで、逆に NaN を含むデータを許容するなどさまざまな条件でのバリデーションを行うことができます。詳細は公式ドキュメントの validate_data() のページをご参照ください。

また、 fit() における validate_data() の引数 reset はデフォルト値の reset=True に設定することが推奨されます。このように設定すると validate_data() の中で特徴量の個数を表す n_features_in_ 属性の定義もしてくれるようになります。なおこの n_features_in_ 属性は、後の predict()transform() 実行時のバリデーションのために利用されます。

注釈

学習データのバリデーションだけであれば check_array()check_X_y() でも行うことができますし、場合によっては np.asarray() でも十分なこともあります。ただし、これらの方法だけでは上で述べた n_features_in_ 属性の定義を行ってくれないため、 n_features_in_ 属性を定義する処理を別途実装する必要がある点にご注意ください。

※なお、 np.asanyarray() , np.atleast_2d() は非推奨

Step 3-3: 学習アルゴリズム#

ポイント

  • モデル固有の学習処理を実装する

  • 学習したパラメータや反復回数を格納する変数の命名規則に注意する

この部分で、モデル固有の学習処理を実装します。基本的には、データからモデルのパラメータを推定してそれを格納する部分の処理を行います。

ここで定義する変数名については、以下の点に注意してください:

  • 学習したパラメータを格納するインスタンス変数名は、末尾にアンダースコアを1つつける(例: coef_ , intercept_

  • 反復アルゴリズムでは、反復回数(整数)を n_iter で管理する

Step 3-4: 終了時の処理#

ポイント

  • fit() メソッドを実行済みかを管理するフラグ _is_fitted を定義する(任意)

  • 返り値として自身( self )を返す

scikit-learnでは、 fit() メソッドを実行済みかを管理しており、例えば fit() を実行せずに predict() を実行しようとした際にエラーを発出する必要があります。 fit() メソッドを実行済みかを判断する方法はStep 4-1にて説明するように複数存在しますが、本稿では fit() 済であるかを確認するためのフラグ _is_fitted を利用することとします。

最後に、 fit() の返り値を自身( self )として fit() の処理は完了です。

Step 4: predict() メソッドの実装#

predict() メソッドでは、 fit() メソッドで計算したパラメータをもとに予測値を算出します。なお、 predict() メソッド実行中に属性を変更してはならないという制約がありますので、その点にご注意ください。

本節では、 predict() メソッドで実装すべき処理について、下記の順番で説明していきます:

  1. fit() メソッドが実行済みか確認

  2. データ X のバリデーション

  3. 予測値の算出・出力

Step 4-1: fit() メソッドが実行済みか確認#

ポイント

  • check_is_fitted() 関数で fit() メソッドが実行済みであるか確認する

predict() メソッドは fit() を実行した後に呼ばれる必要があるメソッドであることから、 fit() をまだ実行していない状態で predict() を実行した際に特定のエラー( NotFittedError )を発出する必要があります。実際には predict() メソッドの冒頭で、 fit() メソッドが実行済みであるか否かをブール値で返す sklearn.utils.validation.check_is_fitted() を呼ぶことになります。

check_is_fitted() のアルゴリズムは以下の通りです:

  1. attributes に引数名(のリスト)を渡した際は、それらの属性がすべて定義されているか否かで fit() 実行済みか判断する

    • 例:attributes=['coef_'] とした場合は、 coef_ 属性が定義されていれば fit() 実行済みと判断される

    • all_or_any=any を追加した場合は、 attributes の引数名のうちいずれか1つが存在すれば fit() 実行済みと判断される

  2. __sklearn_is_fitted__() メソッドがあれば、それを実行する

  3. 上記1., 2.のいずれでもないなら、末尾にアンダースコア(_)の付いた属性が1つでも存在すれば fit() 実行済みと判断する

詳細は __sklearn_is_fitted__ as Developer API もご参照ください。

この後紹介する実行例では上の2.の方式を採用し、 __sklearn_is_fitted__() メソッドを定義してその中でStep 3-4で用意した _is_fitted のフラグを確認することにします。

Step 4-2: データ X のバリデーション#

ポイント

  • fit() 時と同様に、 sklearn.utils.validation.validate_data() 関数でデータのバリデーションを行う

ここは fit() における X , y のバリデーションと同様です。ただし、

  • 予測時には説明変数のデータ X のみを受け取るため引数 y は指定しない

  • 引数 resetreset=False に設定する

という点のみ異なります。

Step 4-3: 予測値の算出・出力#

ポイント

  • 予測値を計算し、予測値を返り値とする

ここでは与えられたデータ X について予測値を算出し、それを返します。その際、本節冒頭でも述べたように属性を変更してはならない点にご注意ください。

Step 5: テスト#

ポイント

  • check_estimator() 関数に作成した回帰モデルを渡し、エラーが出ないか確認する

このステップでは最後に、作成したクラスがscikit-learnの推定器として必要な機能を備えているかを確認します。確認にはテスト用の sklearn.utils.estimator_checks.check_estimator() 関数を使うのが簡単です。この関数では推定器のインスタンスを受け取ると

  • 変数名や各メソッド内の処理に非推奨事項が無いか

  • 特定の状況で特定のエラーが発出されるか

  • 線形モデルから生成されたデータに対する予測精度が著しく低くなっていないか

をテストし、その結果をリストで返します。例えば、作成した MyEstimator クラスをテストしたい場合は check_estimator(MyEstimator()) を実行してエラーが出なければ合格です。エラーが出た場合はエラーメッセージに従ってコードを修正します。

ただし、このテストはあくまでscikit-learnの推定器として必要なインターフェースが整っているかの確認であり、学習・予測アルゴリズムの実装が正しいかの確認までは行えない点にご注意ください。

check_estimator() のほかにも sklearn.utils にはテストの際に便利な関数が実装されています。例えば値が近いか判定する sklearn.utils._testing.assert_allclose() 関数などです。

実装例#

最後に、本稿の内容に従って実装した推定器をいくつか紹介します。内容の理解の一助として頂ければ幸いです。

例1: リッジ回帰(切片なし)#

リッジ回帰自体は sklearn.linear_model.Ridge で実装されていますが、本稿の内容への理解を深めるための例としてここで紹介します。なお、簡単のために切片はないものとします。

リッジ回帰の詳細についてはスパースモデリング(基本編)#RIDGEをご参照ください。

# scikit-learnのバージョン確認
import sklearn

print(sklearn.__version__)
1.6.1
import numpy as np
from sklearn.base import BaseEstimator, RegressorMixin, _fit_context
from sklearn.utils.validation import validate_data, check_is_fitted
from sklearn.utils.estimator_checks import check_estimator
from sklearn.utils._param_validation import Interval
from numbers import Real  # Intervalの引数用。


class MyRidge(RegressorMixin, BaseEstimator):
    # [Step 3-1] パラメータのバリデーション用
    _parameter_constraints = {
        # alphaはNoneまたは0以上の実数である、の意
        "alpha": [None, Interval(type=Real, left=0, right=None, closed="left")],
    }

    def __init__(
        self,
        alpha: float = 0.2,
    ):
        self.alpha = alpha

    # [Step 3-1] パラメータのバリデーション(_fit_contextによるもの)
    @_fit_context(prefer_skip_nested_validation=True)
    def fit(self, X, y):
        # 追加のバリデーション処理が必要ならばここに記載

        # [Step 3-2] 入力データのバリデーション
        X, y = validate_data(self, X, y)

        # [Step 3-3] 学習アルゴリズム(手法により変化)
        coef = (
            np.linalg.inv(X.T @ X + self.alpha * np.eye(self.n_features_in_)) @ X.T @ y
        )

        # [POINT]fitの中で計算される属性名は末尾に_を付ける。
        self.coef_ = coef

        # [Step 3-4] 終了時の処理
        self._is_fitted = True  # <- predict冒頭で、fit済みか確認するためのフラグ

        return self

    def predict(self, X):
        # [Step 4-1] predict冒頭でfit済みかを確認。
        check_is_fitted(self)  # この中で下の__sklearn_is_fitted__が呼ばれる

        # [Step 4-2] 入力データのバリデーション
        X = validate_data(self, X, reset=False)

        # [Step 4-3] 予測値の算出・出力
        return X @ self.coef_

    # [Step 4-1] fit済みか確認するためのメソッド定義
    def __sklearn_is_fitted__(self):
        return hasattr(self, "_is_fitted") and self._is_fitted


# [Step 5] テスト
result_regr1 = check_estimator(MyRidge())  # エラーが出なければ合格

上の実装例の _parameter_constraints で利用している Interval インスタンスは、引数がある区間に属するという制約を意味します。詳細は付録の_parameter_constraints の書き方をご参照ください。

例2: 常に目的変数の平均値を予測値として返すベースライン推定器#

次に、常に訓練データにおける目的変数の平均値を予測値として返すベースライン推定器の実装例を紹介します。

class BaselineMean(RegressorMixin, BaseEstimator):
    _parameter_constraints = {}

    def __init__(self):
        pass

    @_fit_context(prefer_skip_nested_validation=True)
    def fit(self, X, y):
        X, y = validate_data(self, X, y)
        self.mean_ = y.mean()

        self._is_fitted = True
        return self

    def predict(self, X):
        check_is_fitted(self)
        X = validate_data(self, X, reset=False)

        return np.full(shape=(X.shape[0]), fill_value=self.mean_)

    def __sklearn_is_fitted__(self):
        return hasattr(self, "_is_fitted") and self._is_fitted

    # [POINT] 予測精度が悪くてもテスト関数をパスできるように設定
    def __sklearn_tags__(self):
        tags = super().__sklearn_tags__()
        tags.regressor_tags.poor_score = True
        return tags


result_regr2 = check_estimator(BaselineMean())

注釈

scikit-learnでは推定器の機能(疎行列に対応しているか否か、どの形式の出力に対応しているかなど)を管理するタグが存在しています。これらのタグは主に、 check_estimator() などのテスト関数において、どのテストを行うべきか(行わなくてよいか)を判断する際に使用されます。

check_estimator() では線形モデルから生成したデータに対して予測精度(決定係数)が悪すぎないかのテストがあり、精度が悪ければエラーとなります。ですが上のベースライン推定器の場合、精度が悪いのは想定された挙動であるため、精度が悪いのが原因でテスト関数をパスできないのは好ましくありません。そこでこの例2ではこの機能を用いて poor_score タグを True に設定することにより、予測精度が悪くてもテストをパスできるように設定しています。

タグの変更は上の実装例のように __sklearn_tags__() をオーバーライドし、その中でタグを設定することで行うことが可能です。各タグの意味などの詳細については公式ドキュメントのEstimator Tagsの項をご参照ください。

参考文献#

付録#

分類器(Classifier)の実装#

分類器(Classifier)も、回帰モデルとほぼ同様の手順で実装できます。ただし、回帰モデルの手順に加えて

  • classes_ 属性に各内部ラベルに対応する y の値を格納する

  • 目的変数 y の値が連続値でないことを確認する

    • sklearn.utils.multiclass.check_classification_targets を利用するとよい

という手順が必要である点にご注意ください。(加えて、 predict_proba()predict_log_proba() , decision_function() を定義する場合はそれらとの整合性が取れている必要もあります。)

# 訓練データの最初のサンプルの目的変数を常に返す分類器
from sklearn.base import ClassifierMixin
from sklearn.utils.multiclass import check_classification_targets


class FirstClassifier(ClassifierMixin, BaseEstimator):
    def fit(self, X, y):
        X, y = validate_data(self, X, y)
        check_classification_targets(y)
        self.classes_ = np.unique(y)

        self._is_fitted = True
        return self

    def predict(self, X):
        check_is_fitted(self)
        X = validate_data(self, X, reset=False)

        return np.full(shape=(X.shape[0],), fill_value=self.classes_[0])

    def __sklearn_is_fitted__(self):
        return hasattr(self, "_is_fitted") and self._is_fitted

    def __sklearn_tags__(self):
        tags = super().__sklearn_tags__()
        tags.classifier_tags.poor_score = True
        return tags


result_trans = check_estimator(FirstClassifier())

前処理(Transformer)の実装#

scikit-learnにおけるTransformerとは、標準化・主成分分析といったデータの変換や前処理を行うオブジェクトのことです。(Transformerというと同名の深層学習モデルが存在しますが、それとは別の概念ですのでご注意ください。)

Transformerを定義する際は TransformerMixin を継承し、 fit()transform() メソッドを実装することになります。 TransformerMixin を継承すると、

  • fit_transform() メソッド

  • set_output() メソッド(条件あり・後述)

メソッドが利用できるようになります。

Transformerにおける fit() の実装方法は回帰モデルの場合と基本的には同様です。ただし Pipeline との兼ね合いで、引数として実際には使用しない y (デフォルト値 None )も記載する必要がある点にご注意ください。 transform() メソッドの実装方法は回帰モデルにおける predict() メソッドの時とほぼ同様で、同じくデータ X を引数とします。ただし返り値が X と同じサンプル数を持つ2次元配列であるという点が異なります。加えてその際、入力と出力で配列のサンプル数を変えたりサンプルの順番を入れ替えてはならない、という点にもご注意ください。

set_output() はTransformerの変換結果を numpy.ndarray 以外の形式、例えば pandas.DataFrame などで返せるように設定するメソッドで、この機能は get_feature_names_out() を定義することで利用可能となります。 get_feature_names_out() は下記のいずれかのクラスを継承することで実装可能です:

詳細な使い方はIntroducing the set_output APIをご参照ください。

# 入力Xを何も変換せずにそのまま返すTransformerの実装例
from sklearn.base import TransformerMixin, OneToOneFeatureMixin


class IdentityTransformer(TransformerMixin, OneToOneFeatureMixin, BaseEstimator):

    def fit(self, X, y=None):
        X = validate_data(self, X)
        self._is_fitted = True
        return self

    def transform(self, X, y=None):
        check_is_fitted(self)
        X = validate_data(self, X, reset=False)
        return X

    def __sklearn_is_fitted__(self):
        return hasattr(self, "_is_fitted") and self._is_fitted


result_trans = check_estimator(IdentityTransformer())

_parameter_constraints の書き方#

_parameter_constraints はキーを変数名(str)、値が制約のリストとなる辞書として定義します。値として複数の要素を持つリストを設定した場合、それらの要素のうちいずれかの条件に合致すれば問題なしと判断されます。なお、キーに含まれないパラメータ、あるいは値を 'no_validation' にしたパラメータには何の制約も課されません。

制約の例は以下の通りです:

_parameter_constraints = {
    'p0': 'no_validation',  # <- p0は何でもよい。
    'p1': ['array-like'],  # <- p1は配列である。
    'p2': [callable],  # <- p2はcallable(関数など)である。
    'p3': [None],  # <- p3はNoneである。
    'p4': [list],  # <- p4はlistである。
    'p5': ['boolean'],  # <- p5はbool値である。
    'p6': ['nan'],  # <- p6はNaNである。
    'p7': [list, None]  # <- p7はlistまたはNoneである。
}

制約として利用可能なオブジェクトは下の表のとおりです:

constraint

許容される値

任意の型

その型のオブジェクト(その型を継承していてもよい)

callable

callableなオブジェクト(関数など)

"array-like"

リストや numpy.ndarray , pandas.DataFrame , pandas.Series など

"sparse matrix"

scipy.sparse の疎行列

"random_state"

None 、非負整数あるいは numpy.random.RandomState インスタンス

"boolean"

bool値(True , False

"verbose"

verbose 引数に入るような値(非負実数やbool値)

"cv_object"

交差検証の方法を指定する cv 引数に入るような値(2以上の整数および split()get_n_splits() メソッドを持つCV splitterなど)

"nan"

numpy.nan

None

None

Intervalオブジェクト

区間上の整数または実数

Optionsオブジェクト

指定した型の、指定した値の中のいずれか

StrOptionsオブジェクト

指定した文字列の中のいずれか

MissingValuesオブジェクト

欠損値処理において、欠損箇所に埋めるような値(整数・実数・ nan など)

HasMethodsオブジェクト

指定したメソッドがすべて実装されているオブジェクト

Hiddenオブジェクト

ユーザーに秘匿したい制約に対して使用

なお、表の Interval , Options , StrOptions , MissingValues , HasMethods , Hiddensklearn.utils._param_validation 内で定義されているクラスです。

詳細については、実際にバリデーションを行う関数である sklearn.utils._param_validation.validate_parameter_constraints() や表中のオブジェクトのdocstringの方もご参照ください:

# docstringの確認方法の一例
from sklearn.utils._param_validation import validate_parameter_constraints

# from sklearn.utils._param_validation import Interval  # Intervalは上の実装例1でインポート済み

print("[validate_parameter_constraintsのdocstring]")
print(validate_parameter_constraints.__doc__)

print("[Intervalのdocstring]")
print(Interval.__doc__)
[validate_parameter_constraintsのdocstring]
Validate types and values of given parameters.

    Parameters
    ----------
    parameter_constraints : dict or {"no_validation"}
        If "no_validation", validation is skipped for this parameter.

        If a dict, it must be a dictionary `param_name: list of constraints`.
        A parameter is valid if it satisfies one of the constraints from the list.
        Constraints can be:
        - an Interval object, representing a continuous or discrete range of numbers
        - the string "array-like"
        - the string "sparse matrix"
        - the string "random_state"
        - callable
        - None, meaning that None is a valid value for the parameter
        - any type, meaning that any instance of this type is valid
        - an Options object, representing a set of elements of a given type
        - a StrOptions object, representing a set of strings
        - the string "boolean"
        - the string "verbose"
        - the string "cv_object"
        - the string "nan"
        - a MissingValues object representing markers for missing values
        - a HasMethods object, representing method(s) an object must have
        - a Hidden object, representing a constraint not meant to be exposed to the user

    params : dict
        A dictionary `param_name: param_value`. The parameters to validate against the
        constraints.

    caller_name : str
        The name of the estimator or function or method that called this function.
    
[Intervalのdocstring]
Constraint representing a typed interval.

    Parameters
    ----------
    type : {numbers.Integral, numbers.Real, RealNotInt}
        The set of numbers in which to set the interval.

        If RealNotInt, only reals that don't have the integer type
        are allowed. For example 1.0 is allowed but 1 is not.

    left : float or int or None
        The left bound of the interval. None means left bound is -∞.

    right : float, int or None
        The right bound of the interval. None means right bound is +∞.

    closed : {"left", "right", "both", "neither"}
        Whether the interval is open or closed. Possible choices are:

        - `"left"`: the interval is closed on the left and open on the right.
          It is equivalent to the interval `[ left, right )`.
        - `"right"`: the interval is closed on the right and open on the left.
          It is equivalent to the interval `( left, right ]`.
        - `"both"`: the interval is closed.
          It is equivalent to the interval `[ left, right ]`.
        - `"neither"`: the interval is open.
          It is equivalent to the interval `( left, right )`.

    Notes
    -----
    Setting a bound to `None` and setting the interval closed is valid. For instance,
    strictly speaking, `Interval(Real, 0, None, closed="both")` corresponds to
    `[0, +∞) U {+∞}`.