Bases: Search
A class which executes a grid search.
Grid search can be used to find the optimal combination of one or more hyperparameters.
search = GridSearch(score_fn=lambda search_idx, a, b: a + b, params={"a": [1, 2, 3], "b": [4, 5, 6]})
search.fit()
print(search.get_best_parameters()) # {"a": 3, "b": 6, "search_idx": 9}
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| score_fn | Callable[..., float] | Objective function that measures search fitness. One of its arguments must be 'search_idx' which will be automatically provided by the search routine. This can help with file saving / logging during the search. | required |
| params | Dict[str, List] | A dictionary with key names matching the `score_fn`'s inputs. Its values should be lists of options. | required |
| best_mode | str | Whether maximal or minimal fitness is desired. Must be either 'min' or 'max'. | 'max' |
| name | str | The name of the search instance. This is used for saving and loading purposes. | 'grid_search' |
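
To make the roles of `search_idx` and `best_mode` concrete, here is a minimal, library-free sketch. `toy_grid_search` is a hypothetical stand-in written for illustration, not FastEstimator's implementation. It assumes `search_idx` is a counter starting at 1, which is consistent with the example above, where nine combinations end at `search_idx` 9.

```python
import itertools
from typing import Any, Callable, Dict, List, Tuple


def toy_grid_search(score_fn: Callable[..., float],
                    params: Dict[str, List],
                    best_mode: str = "max") -> Tuple[Dict[str, Any], float]:
    """Hypothetical stand-in for a grid search; not the FastEstimator code."""
    assert best_mode in ("min", "max"), "best_mode must be 'min' or 'max'"
    results = []
    combos = itertools.product(*params.values())
    # Assumption: search_idx counts evaluations, starting from 1.
    for search_idx, combo in enumerate(combos, start=1):
        kwargs = dict(zip(params, combo), search_idx=search_idx)
        results.append((kwargs, score_fn(**kwargs)))
    # best_mode decides whether the highest or the lowest score wins.
    pick = max if best_mode == "max" else min
    return pick(results, key=lambda item: item[1])


def score(search_idx, a, b):
    # search_idx is not used for scoring here, but it could name a per-run log file.
    return a + b


grid = {"a": [1, 2, 3], "b": [4, 5, 6]}
best_max = toy_grid_search(score, grid, best_mode="max")
best_min = toy_grid_search(score, grid, best_mode="min")
print(best_max)  # ({'a': 3, 'b': 6, 'search_idx': 9}, 9)
print(best_min)  # ({'a': 1, 'b': 4, 'search_idx': 1}, 5)
```

With `best_mode="min"` the same grid and score function select the opposite corner of the grid, which is the behavior to expect when the score is a loss rather than a reward.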
Raises:

| Type | Description |
| --- | --- |
| AssertionError | If `params` is not a dictionary, or contains a key not used by `score_fn`. |
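
The two conditions behind this `AssertionError` can be reproduced outside the library. The sketch below copies the checks from the constructor in the source listing into a hypothetical helper, `check_params`, which is not part of FastEstimator.

```python
import inspect
from typing import Callable, Dict, List


def check_params(score_fn: Callable[..., float], params: Dict[str, List]) -> None:
    """Hypothetical helper mirroring the two constructor assertions."""
    assert isinstance(params, dict), "must provide params as a dictionary"
    score_fn_args = set(inspect.signature(score_fn).parameters.keys())
    params_args = set(params.keys())
    # Every key in params must be an argument name of score_fn.
    assert score_fn_args.issuperset(params_args), \
        "unused param {} in score_fn".format(params_args - score_fn_args)


def score(search_idx, a, b):
    return a + b


check_params(score, {"a": [1, 2], "b": [3]})  # passes silently

message = None
try:
    check_params(score, {"a": [1, 2], "c": [3]})  # 'c' is not an argument of score
except AssertionError as err:
    message = str(err)
    print(message)  # unused param {'c'} in score_fn
```

Note that the check is one-directional: it rejects keys in `params` that `score_fn` does not accept, but it does not require `params` to cover every argument of `score_fn`.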
Source code in fastestimator\fastestimator\search\grid_search.py
class GridSearch(Search):
    """A class which executes a grid search.

    Grid search can be used to find the optimal combination of one or more hyperparameters.

    ```python
    search = GridSearch(score_fn=lambda search_idx, a, b: a + b, params={"a": [1, 2, 3], "b": [4, 5, 6]})
    search.fit()
    print(search.get_best_parameters()) # {"a": 3, "b": 6, "search_idx": 9}
    ```

    Args:
        score_fn: Objective function that measures search fitness. One of its arguments must be 'search_idx' which will
            be automatically provided by the search routine. This can help with file saving / logging during the search.
        params: A dictionary with key names matching the `score_fn`'s inputs. Its values should be lists of options.
        best_mode: Whether maximal or minimal fitness is desired. Must be either 'min' or 'max'.
        name: The name of the search instance. This is used for saving and loading purposes.

    Raises:
        AssertionError: If `params` is not dictionary, or contains key not used by `score_fn`
    """
    def __init__(self,
                 score_fn: Callable[..., float],
                 params: Dict[str, List],
                 best_mode: str = "max",
                 name: str = "grid_search"):
        assert isinstance(params, dict), "must provide params as a dictionary"
        score_fn_args, params_args = set(inspect.signature(score_fn).parameters.keys()), set(params.keys())
        assert score_fn_args.issuperset(params_args), "unused param {} in score_fn".format(params_args - score_fn_args)
        super().__init__(score_fn=score_fn, best_mode=best_mode, name=name)
        self.params = params

    def _fit(self):
        experiments = (dict(zip(self.params, x)) for x in itertools.product(*self.params.values()))
        for exp in experiments:
            self.evaluate(**exp)
        best_results = self.get_best_results()
        print("FastEstimator-Search: Grid Search Finished, best parameters: {}, best score: {}".format(
            best_results[0], best_results[1]))
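
The enumeration in `_fit` relies on `itertools.product`, so the number of evaluations and their order can be inspected without running a search. The following sketch reuses the same expression on a plain dictionary; the variable names are illustrative only.

```python
import itertools
import math

params = {"a": [1, 2, 3], "b": [4, 5, 6]}

# Same expression shape as in _fit: one keyword dictionary per combination.
experiments = [dict(zip(params, x)) for x in itertools.product(*params.values())]

# The number of evaluations is the product of the option-list lengths.
expected_count = math.prod(len(v) for v in params.values())

print(len(experiments), expected_count)  # 9 9
print(experiments[:4])
# [{'a': 1, 'b': 4}, {'a': 1, 'b': 5}, {'a': 1, 'b': 6}, {'a': 2, 'b': 4}]
```

The last key varies fastest, and the total grows multiplicatively with each added parameter, so it is worth estimating the grid size before passing an expensive `score_fn`.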