ProbabilityDistributionEditor.java
package org.opentrafficsim.swing.gui;
import java.awt.FontMetrics;
import java.awt.Graphics;
import java.awt.Graphics2D;
import java.awt.RenderingHints;
import java.awt.geom.Rectangle2D;
import java.util.ArrayList;
import java.util.Hashtable;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Objects;
import java.util.function.BiFunction;
import java.util.stream.IntStream;
import javax.swing.JLabel;
import javax.swing.event.ChangeEvent;
import javax.swing.event.ChangeListener;
import org.djutils.exceptions.Throw;
import org.djutils.swing.multislider.LinearMultiSlider;
/**
 * Editor for a distribution of probabilities of all possible categories. The probabilities must sum to 1.0.
 * <p>
 * The editor is a multi-slider over the range [0...1] with one thumb less than there are categories. Thumb {@code i} is
 * initially placed at the cumulative probability of categories {@code 0...i}. The band between two consecutive edges (track
 * start, thumbs, track end) thus represents the probability of one category, and is labeled with a category label.
 * </p>
 * <p>
 * Copyright (c) 2024-2024 Delft University of Technology, PO Box 5, 2600 AA, Delft, the Netherlands. All rights reserved. <br>
 * BSD-style license. See <a href="https://opentrafficsim.org/docs/license.html">OpenTrafficSim License</a>.
 * </p>
 * @author <a href="https://github.com/averbraeck">Alexander Verbraeck</a>
 * @author <a href="https://github.com/peter-knoppers">Peter Knoppers</a>
 * @author <a href="https://github.com/wjschakel">Wouter Schakel</a>
 * @param <T> category type
 */
public class ProbabilityDistributionEditor<T> extends LinearMultiSlider<Double>
{
    /** */
    private static final long serialVersionUID = 20250916L;

    /** Number of values the slider allows per percent; the slider has {@code 100 * valuesPerPercent + 1} values. */
    private final int valuesPerPercent;

    /** Categories, in the order of the probability bands. */
    private final List<T> categories;

    /** Label function, receiving the category and its probability in the normalized [0...1] range. */
    private BiFunction<T, Double, String> labelFunction = (t, p) -> String.format("%s: %.1f%%", t, p * 100.0);

    /** Category font size. */
    private float categoryFontSize = 10.0f;

    /**
     * Constructor.
     * @param categories categories, must be unique and have the same number of elements as {@code probabilities}
     * @param probabilities probabilities, must be non-negative and add up to one
     * @param valuesPerPercent number of values the slider allows per percent
     * @throws NullPointerException if {@code categories} or {@code probabilities} is null
     * @throws IllegalArgumentException if there are fewer than two categories, the number of categories and probabilities
     *             differ, the categories are not unique, a probability is negative or NaN, the probabilities do not add up
     *             to one, or {@code valuesPerPercent < 1}
     */
    public ProbabilityDistributionEditor(final List<T> categories, final double[] probabilities, final int valuesPerPercent)
    {
        // arguments are validated in static helpers, as they are needed before the super constructor can be invoked
        super(0.0, 1.0, numberOfSliderValues(valuesPerPercent), checkValues(categories, probabilities));
        Throw.when(categories.size() != new LinkedHashSet<>(categories).size(), IllegalArgumentException.class,
                "The categories are not unique.");
        this.categories = new ArrayList<>(categories);
        this.valuesPerPercent = valuesPerPercent;
        // create default %-labels every 10%, although not shown by default
        setLabelTable(new Hashtable<Integer, JLabel>(IntStream.range(0, 11).collect(() -> new LinkedHashMap<>(),
                (m, i) -> m.put(i * 10 * valuesPerPercent, new JLabel((i * 10) + "%")), (m1, m2) -> m1.putAll(m2))));
        setMajorTickSpacing(5 * valuesPerPercent);
        setPaintTicks(true);
        setOverlap(true);
        setPaintTrack(false);
        this.addChangeListener(new ChangeListener()
        {
            @Override
            public void stateChanged(final ChangeEvent e)
            {
                // need to update the labels as we drag, otherwise the thumbs erase (part of) the labels
                repaint();
            }
        });
    }

    /**
     * Validates the number of values per percent and returns the total number of slider values.
     * @param valuesPerPercent number of values the slider allows per percent
     * @return total number of slider values, covering 0% up to and including 100%
     * @throws IllegalArgumentException if {@code valuesPerPercent < 1}
     */
    private static int numberOfSliderValues(final int valuesPerPercent)
    {
        Throw.when(valuesPerPercent < 1, IllegalArgumentException.class, "valuesPerPercent should be at least 1.");
        return 100 * valuesPerPercent + 1;
    }

    /**
     * Check the categories match the probabilities, and that the probabilities are non-negative and add up to one.
     * @param categories categories
     * @param probabilities probabilities
     * @return cumulative probabilities of all but the last category, i.e. the initial thumb values
     * @throws NullPointerException if {@code categories} or {@code probabilities} is null
     * @throws IllegalArgumentException if the input is invalid
     */
    private static Double[] checkValues(final List<?> categories, final double[] probabilities)
    {
        Throw.whenNull(categories, "categories");
        Throw.whenNull(probabilities, "probabilities");
        Throw.when(probabilities.length < 2, IllegalArgumentException.class, "At least two categories are required.");
        Throw.when(categories.size() != probabilities.length, IllegalArgumentException.class,
                "Number of categories (%s) does not match number of probabilities (%s).", categories.size(),
                probabilities.length);
        Double[] out = new Double[probabilities.length - 1];
        double cumul = 0.0;
        for (int i = 0; i < probabilities.length; i++)
        {
            // written as !(p >= 0) so that NaN is rejected too; the last probability is checked as well
            Throw.when(!(probabilities[i] >= 0.0), IllegalArgumentException.class,
                    "Probabilities should not be negative or NaN.");
            cumul += probabilities[i];
            if (i < out.length)
            {
                // rounding errors may push the cumulative value marginally beyond the slider maximum of 1.0
                out[i] = Math.min(cumul, 1.0);
            }
        }
        // written as !(... <= ...) so that a NaN or infinite sum is rejected
        Throw.when(!(Math.abs(cumul - 1.0) <= 1e-9), IllegalArgumentException.class, "Probabilities do not add up to one.");
        return out;
    }

    /**
     * Maps a slider index to a value in the normalized [0...1] range.
     * @param index slider index, from 0 to {@code 100 * valuesPerPercent}
     * @return value in the range [0...1]
     */
    @Override
    protected Double mapIndexToValue(final int index)
    {
        return index / (100.0 * this.valuesPerPercent);
    }

    /**
     * Sets the label function. This function receives the category object and the probability in the normalized [0...1] range.
     * If the function throws a runtime exception or returns {@code null}, the category's string value is shown instead.
     * @param labelfunction label function receiving the category object and the probability in the normalized [0...1] range
     * @throws NullPointerException if {@code labelfunction} is null
     */
    public void setCategoryLabelFunction(final BiFunction<T, Double, String> labelfunction)
    {
        Throw.whenNull(labelfunction, "labelfunction");
        this.labelFunction = labelfunction;
        repaint();
    }

    /**
     * Set the font size for the category labels.
     * @param categoryFontSize font size for the category labels
     * @throws IllegalArgumentException if {@code categoryFontSize} is not larger than 0 (or is NaN)
     */
    public void setCategoryFontSize(final float categoryFontSize)
    {
        Throw.when(!(categoryFontSize > 0.0f), IllegalArgumentException.class, "Category font size should be larger than 0.");
        this.categoryFontSize = categoryFontSize;
        repaint();
    }

    /** {@inheritDoc} */
    @Override
    public void paint(final Graphics g)
    {
        super.paint(g);
        // paint on a copy, so the font and rendering hint set below do not leak into the caller's Graphics object
        Graphics2D g2 = (Graphics2D) g.create();
        try
        {
            g2.setRenderingHint(RenderingHints.KEY_ANTIALIASING, RenderingHints.VALUE_ANTIALIAS_ON);
            g2.setFont(g2.getFont().deriveFont(this.categoryFontSize));
            FontMetrics fontMetric = g2.getFontMetrics();
            int baseline = fontMetric.getHeight();
            int[] edges = bandEdgesPx();
            for (int i = 0; i < edges.length - 1; i++)
            {
                String categoryLabel = categoryLabel(i);
                Rectangle2D bounds = fontMetric.getStringBounds(categoryLabel, g2);
                // center the label on its band, but keep it horizontally within the component
                int center = (edges[i] + edges[i + 1]) / 2;
                int left = (int) (center - bounds.getWidth() / 2.0);
                if (left < 0)
                {
                    left = 0;
                }
                else if (left + bounds.getWidth() > getWidth())
                {
                    left = (int) (getWidth() - bounds.getWidth());
                }
                g2.drawString(categoryLabel, left, baseline);
            }
        }
        finally
        {
            g2.dispose();
        }
    }

    /**
     * Returns the pixel positions of the edges between probability bands: the low end of the track, each thumb position, and
     * the high end of the track.
     * @return pixel positions of the band edges, with length equal to the number of thumbs plus two
     */
    private int[] bandEdgesPx()
    {
        int[] edges = new int[getNumberOfThumbs() + 2];
        edges[0] = getTrackSizeLoPx();
        for (int i = 0; i < getNumberOfThumbs(); i++)
        {
            edges[i + 1] = thumbPositionPx(i);
        }
        edges[edges.length - 1] = getTrackSizeLoPx() + trackSize();
        return edges;
    }

    /**
     * Returns the label to draw for the category with given index, falling back to the category's string value.
     * @param i category index
     * @return label for the category, never {@code null}
     */
    private String categoryLabel(final int i)
    {
        T category = this.categories.get(i);
        try
        {
            String label = this.labelFunction.apply(category, getProbability(i));
            return label != null ? label : String.valueOf(category);
        }
        catch (RuntimeException exception)
        {
            // a faulty user-supplied label function should not break painting of the component
            return String.valueOf(category);
        }
    }

    /**
     * Retrieve the current probability values.
     * @return the probability values, in the order of the categories
     */
    public double[] getProbabilities()
    {
        double[] result = new double[this.categories.size()];
        for (int i = 0; i < this.categories.size(); i++)
        {
            result[i] = getProbability(i);
        }
        return result;
    }

    /**
     * Returns the probability of the given category.
     * @param t category
     * @return the probability of the given category
     * @throws IllegalArgumentException if the category object is not part of the distribution
     */
    public double getProbability(final T t)
    {
        int index = this.categories.indexOf(t);
        Throw.when(index < 0, IllegalArgumentException.class, "Category %s is not part of the distribution.", t);
        return getProbability(index);
    }

    /**
     * Returns the probability of category with given index, i.e. the width of its band between the neighboring thumbs (or the
     * track ends for the first and last category).
     * @param i category index
     * @return the probability of category with given index
     * @throws IndexOutOfBoundsException if the index is out of bounds
     */
    public double getProbability(final int i)
    {
        Objects.checkIndex(i, this.categories.size());
        if (i == 0)
        {
            return getValue(0);
        }
        else if (i == this.categories.size() - 1)
        {
            return 1.0 - getValue(this.categories.size() - 2);
        }
        return getValue(i) - getValue(i - 1);
    }
}