
grid_search

GridSearch

Bases: Search

A class which executes a grid search.

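Grid search exhaustively evaluates every combination of candidate values (the Cartesian product of the lists in `params`) and can be used to find the optimal combination of one or more hyperparameters.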

```python
from fastestimator.search import GridSearch

search = GridSearch(score_fn=lambda search_idx, a, b: a + b, params={"a": [1, 2, 3], "b": [4, 5, 6]})
search.fit()
print(search.get_best_parameters()) # {"a": 3, "b": 6, "search_idx": 9}
```
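
The behavior in this example can be reproduced without FastEstimator. The sketch below is a hypothetical stand-in (`toy_grid_search` is not part of the library): it assumes `search_idx` counts evaluations starting from 1, which matches the `"search_idx": 9` reported above for a 3x3 grid, and it keeps the earliest run when two scores tie, which is also an assumption.

```python
import itertools
from typing import Any, Callable, Dict, List, Tuple


def toy_grid_search(score_fn: Callable[..., float],
                    params: Dict[str, List[Any]],
                    best_mode: str = "max") -> Tuple[Dict[str, Any], float]:
    """Hypothetical stand-in for GridSearch: score every combination, return the best one."""
    assert best_mode in ("min", "max"), "best_mode must be 'min' or 'max'"
    best_params, best_score = None, None
    # itertools.product walks the Cartesian product; the last key varies fastest
    combos = itertools.product(*params.values())
    for search_idx, values in enumerate(combos, start=1):  # assumed 1-based index
        kwargs = dict(zip(params, values))
        score = score_fn(search_idx=search_idx, **kwargs)
        # Strict comparison: on a tie the earlier run is kept (an assumption)
        better = best_score is None or (score > best_score if best_mode == "max" else score < best_score)
        if better:
            best_params, best_score = {**kwargs, "search_idx": search_idx}, score
    return best_params, best_score


best, score = toy_grid_search(lambda search_idx, a, b: a + b, {"a": [1, 2, 3], "b": [4, 5, 6]})
print(best, score)  # {'a': 3, 'b': 6, 'search_idx': 9} 9
```

Passing `best_mode="min"` to the same call would instead return `{'a': 1, 'b': 4, 'search_idx': 1}` with a score of 5.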

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `score_fn` | `Callable[..., float]` | Objective function that measures search fitness. One of its arguments must be `search_idx`, which is provided automatically by the search routine. This can help with file saving / logging during the search. | required |
| `params` | `Dict[str, List]` | A dictionary whose keys are argument names of `score_fn`. Each value is a list of candidate values for that argument. | required |
| `best_mode` | `str` | Whether maximal or minimal fitness is desired. Must be either `'min'` or `'max'`. | `'max'` |
| `name` | `str` | The name of the search instance. This is used for saving and loading purposes. | `'grid_search'` |
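
The `search_idx` argument is what lets a `score_fn` keep per-trial artifacts apart. The sketch below is hypothetical: `make_score_fn`, the `lr`/`batch_size` hyperparameters, the `trial_<idx>.txt` file naming, and the toy loss formula are all made up for illustration. Because it returns a loss (lower is better), it would be paired with `best_mode='min'`.

```python
import os
import tempfile


def make_score_fn(log_dir: str):
    """Return a hypothetical score_fn that logs each trial to its own file."""
    def score_fn(search_idx: int, lr: float, batch_size: int) -> float:
        # Toy objective standing in for a validation loss: lower is better,
        # so the search would be configured with best_mode="min"
        loss = abs(lr - 0.01) * 100 + batch_size / 1000
        # search_idx differs for every trial, so it yields a unique file name
        with open(os.path.join(log_dir, "trial_{}.txt".format(search_idx)), "w") as f:
            f.write("lr={} batch_size={} loss={}\n".format(lr, batch_size, loss))
        return loss
    return score_fn


log_dir = tempfile.mkdtemp()
score_fn = make_score_fn(log_dir)
print(score_fn(search_idx=1, lr=0.01, batch_size=32))  # 0.032
```

With FastEstimator installed, such a function would be passed as `GridSearch(score_fn=score_fn, params={"lr": [1e-3, 1e-2, 1e-1], "batch_size": [32, 64]}, best_mode="min")`, leaving `search_idx` for the search routine to fill in.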

Raises:

| Type | Description |
|---|---|
| `AssertionError` | If `params` is not a dictionary, or contains a key that is not an argument of `score_fn`. |
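
The two assertions in `__init__` can be exercised on their own. `check_params` below is a hypothetical helper that copies the same logic. The check runs in one direction only: `score_fn` may take arguments that are absent from `params` (such as `search_idx`), but every key of `params` must be an argument of `score_fn`.

```python
import inspect
from typing import Callable, Dict, List


def check_params(score_fn: Callable[..., float], params: Dict[str, List]) -> None:
    """Hypothetical helper mirroring the assertions in GridSearch.__init__."""
    assert isinstance(params, dict), "must provide params as a dictionary"
    score_fn_args = set(inspect.signature(score_fn).parameters.keys())
    # Keys of params that score_fn cannot accept
    unused = set(params.keys()) - score_fn_args
    assert not unused, "unused param {} in score_fn".format(unused)


def score_fn(search_idx, a, b):
    return a + b


check_params(score_fn, {"a": [1, 2], "b": [3]})  # passes: every key is an argument

for bad in ([("a", [1, 2])], {"a": [1, 2], "c": [0]}):
    try:
        check_params(score_fn, bad)
    except AssertionError as err:
        print("AssertionError:", err)
# AssertionError: must provide params as a dictionary
# AssertionError: unused param {'c'} in score_fn
```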

Source code in fastestimator\fastestimator\search\grid_search.py
````python
import inspect
import itertools
from typing import Callable, Dict, List

from fastestimator.search.search import Search


class GridSearch(Search):
    """A class which executes a grid search.

    Grid search can be used to find the optimal combination of one or more hyperparameters.

    ```python
    search = GridSearch(score_fn=lambda search_idx, a, b: a + b, params={"a": [1, 2, 3], "b": [4, 5, 6]})
    search.fit()
    print(search.get_best_parameters()) # {"a": 3, "b": 6, "search_idx": 9}
    ```

    Args:
        score_fn: Objective function that measures search fitness. One of its arguments must be 'search_idx' which will
            be automatically provided by the search routine. This can help with file saving / logging during the search.
        params: A dictionary with key names matching the `score_fn`'s inputs. Its values should be lists of options.
        best_mode: Whether maximal or minimal fitness is desired. Must be either 'min' or 'max'.
        name: The name of the search instance. This is used for saving and loading purposes.

    Raises:
        AssertionError: If `params` is not a dictionary, or contains a key not used by `score_fn`.
    """
    def __init__(self,
                 score_fn: Callable[..., float],
                 params: Dict[str, List],
                 best_mode: str = "max",
                 name: str = "grid_search"):
        assert isinstance(params, dict), "must provide params as a dictionary"
        # Every key in params must be an argument name of score_fn
        score_fn_args, params_args = set(inspect.signature(score_fn).parameters.keys()), set(params.keys())
        assert score_fn_args.issuperset(params_args), "unused param {} in score_fn".format(params_args - score_fn_args)
        super().__init__(score_fn=score_fn, best_mode=best_mode, name=name)
        self.params = params

    def _fit(self):
        # Lazily build one keyword dict per point in the Cartesian product of the candidate lists
        experiments = (dict(zip(self.params, x)) for x in itertools.product(*self.params.values()))
        # Score every combination through the base class's evaluate()
        for exp in experiments:
            self.evaluate(**exp)
        best_results = self.get_best_results()
        print("FastEstimator-Search: Grid Search Finished, best parameters: {}, best score: {}".format(
            best_results[0], best_results[1]))
````
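
Because `_fit` evaluates every point of the Cartesian product, the number of `score_fn` calls is the product of the list lengths, so it grows multiplicatively with each added hyperparameter. The snippet below reuses the generator expression from `_fit` on the docstring's example grid to show the count and the visiting order (the last key varies fastest, following `itertools.product` and dict insertion order). `math.prod` requires Python 3.8 or newer.

```python
import itertools
import math

params = {"a": [1, 2, 3], "b": [4, 5, 6]}

# Same expression as in _fit: one keyword dict per combination, built lazily
experiments = (dict(zip(params, x)) for x in itertools.product(*params.values()))
runs = list(experiments)

print(len(runs), math.prod(len(v) for v in params.values()))  # 9 9
print(runs[:4])  # [{'a': 1, 'b': 4}, {'a': 1, 'b': 5}, {'a': 1, 'b': 6}, {'a': 2, 'b': 4}]
```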