
Many machine learning systems produce hard class predictions by first predicting the probability of an event and then predicting that the event occurs when that probability is above 0.5. This adjustment lets practitioners use a threshold other than 0.5. Choosing an appropriate threshold balances the trade-off between different types of errors (such as false positives and false negatives) so that the model's predictions suit a specific use case.
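As a rough illustration of the rule described above, the following base R sketch applies a probability threshold by hand. It is not the tailor implementation: `classify()` is a hypothetical helper, the probabilities are rounded values in the spirit of this page's example data, and the handling of a probability exactly equal to the threshold is an assumption.

```r
# Illustrative sketch only; `classify()` is a hypothetical helper, not part of tailor.
# Probabilities of the event level (Class1), rounded for readability.
prob_class1 <- c(0.0036, 0.679, 0.111, 0.735, 0.016, 0.999)

# Predict the event level when its probability is above `threshold`,
# otherwise predict the other level. Tie handling (prob == threshold)
# is an assumption here and may differ in the package.
classify <- function(prob, threshold = 0.5, levels = c("Class1", "Class2")) {
  factor(ifelse(prob > threshold, levels[1], levels[2]), levels = levels)
}

at_default <- classify(prob_class1)                   # threshold of 0.5
at_lower   <- classify(prob_class1, threshold = 0.1)  # flags more events

# Compare the two sets of hard class predictions side by side
data.frame(prob_class1, at_default, at_lower)
```

Lowering the threshold can only move predictions toward the event level, which is why the third observation (probability 0.111) switches from Class2 to Class1 while the others stay the same.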

Usage

adjust_probability_threshold(x, threshold = 0.5)

Arguments

x

A tailor().

threshold

A numeric value (between zero and one) or hardhat::tune().
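Since threshold accepts hardhat::tune(), the value can be left as a placeholder to be chosen later by a tuning routine. A minimal sketch, assuming the tailor and hardhat packages are installed; the object name `tlr_tune` is only illustrative, and the tuning step itself (for example with the tune package) is not shown:

```r
library(tailor)

# Mark the threshold for tuning instead of fixing it up front.
# `tlr_tune` is an illustrative name, not something the package defines.
tlr_tune <- adjust_probability_threshold(tailor(), threshold = hardhat::tune())

tlr_tune
```

A tailor with a tune() placeholder is meant to be finalized with a concrete threshold before it is fit.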

Data Usage

This adjustment doesn't require estimation, so the same data can be passed to both fit() and predict(). Fitting this adjustment only collects metadata on the supplied column names and does not risk data leakage.

Examples

library(tailor)
library(modeldata)

# `predicted` gives hard class predictions based on probability threshold .5
head(two_class_example)
#>    truth      Class1       Class2 predicted
#> 1 Class2 0.003589243 0.9964107574    Class2
#> 2 Class1 0.678621054 0.3213789460    Class1
#> 3 Class2 0.110893522 0.8891064779    Class2
#> 4 Class1 0.735161703 0.2648382969    Class1
#> 5 Class2 0.016239960 0.9837600397    Class2
#> 6 Class1 0.999275071 0.0007249286    Class1

# use a threshold of .1 instead:
tlr <-
  tailor() %>%
  adjust_probability_threshold(.1)

# fit by supplying column names. to avoid doing this manually, add the
# tailor to a modeling workflow with `workflows::add_tailor()`
tlr_fit <- fit(
  tlr,
  two_class_example,
  outcome = c(truth),
  estimate = c(predicted),
  probabilities = c(Class1, Class2)
)

# adjust hard class predictions
predict(tlr_fit, two_class_example) %>% head()
#> # A tibble: 6 × 4
#>   truth   Class1   Class2 predicted
#>   <fct>    <dbl>    <dbl> <fct>    
#> 1 Class2 0.00359 0.996    Class2   
#> 2 Class1 0.679   0.321    Class1   
#> 3 Class2 0.111   0.889    Class1   
#> 4 Class1 0.735   0.265    Class1   
#> 5 Class2 0.0162  0.984    Class2   
#> 6 Class1 0.999   0.000725 Class1
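
To make the trade-off between false positives and false negatives concrete, the sketch below counts both kinds of error in two_class_example at the default threshold and at the threshold of .1 used above. It uses base R only; `count_errors()` is a hypothetical helper written for this illustration, not part of tailor, and it treats Class1 as the event.

```r
library(modeldata)

# Hypothetical helper for illustration; not part of tailor.
# Treats Class1 as the event and counts each type of error at `threshold`.
count_errors <- function(threshold) {
  pred  <- ifelse(two_class_example$Class1 > threshold, "Class1", "Class2")
  truth <- as.character(two_class_example$truth)
  c(
    threshold       = threshold,
    false_positives = sum(pred == "Class1" & truth == "Class2"),
    false_negatives = sum(pred == "Class2" & truth == "Class1")
  )
}

# One row per threshold
errors <- rbind(count_errors(0.5), count_errors(0.1))
errors
```

Lowering the threshold can only reduce (or leave unchanged) the number of false negatives, at the cost of the same or more false positives; which balance is preferable depends on the use case.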