Comments (15)
I think it is a great addition.
@maximsch2 is there a specific framework that you had in mind where this would give better integration?
Yeah, maybe something like this: https://ax.dev/tutorials/tune_cnn.html
But I do imagine that any sort of sweeping requires us to be able to a) select the target metric and b) compare two runs to see which one achieved the better metric value.
@maximsch2 what do you think about something like this:
from copy import deepcopy

import torch
from torchmetrics import MeanSquaredError

# registry mapping metric class -> (minimize, comparison fn, initial value, output index)
_REGISTER = {}

def register(metric, minimize, index=None):
    if minimize:
        compare_fn = torch.less
        init_val = torch.tensor(float("inf"))
    else:
        compare_fn = torch.greater
        init_val = -torch.tensor(float("inf"))
    _REGISTER[metric] = (minimize, compare_fn, init_val, index)

register(MeanSquaredError, True)

class MetricCompare:
    def __init__(self, metric):
        self.base_metric = metric
        minimize, compare_fn, init_val, index = _REGISTER[type(metric)]
        self._minimize = minimize
        self._compare_fn = compare_fn
        self._index = index
        self._init_val = init_val
        self._new_val = deepcopy(init_val)
        self._old_val = deepcopy(init_val)

    def update(self, *args, **kwargs):
        self.base_metric.update(*args, **kwargs)

    def compute(self):
        self._old_val = self._new_val
        val = self.base_metric.compute()
        self._new_val = val.detach()
        return val

    def reset(self):
        self.base_metric.reset()
        self._new_val = deepcopy(self._init_val)
        self._old_val = deepcopy(self._init_val)

    @property
    def has_improved(self):
        if self._index is None:
            return self._compare_fn(self._new_val, self._old_val)
        return self._compare_fn(self._new_val[self._index], self._old_val[self._index])

    @property
    def minimize(self):
        return self._minimize

    @property
    def maximize(self):
        return not self.minimize

metric = MetricCompare(MeanSquaredError())
metric.update(torch.randn(100,), torch.randn(100,))
val = metric.compute()
print(metric.has_improved)
This is basically a wrapper for metrics that adds extra properties telling whether the metric should be minimized or maximized and, after compute is called, whether it has improved.
Usually sweeps will be run in a distributed fashion (e.g. schedule runs with different hyperparameters separately, compute metric values, pick the one with the best metric), so has_improved might not be as useful there.
Thinking about it a bit more, just providing a way to convert a metric to an optimization value might be enough (with a defined semantics of whether we are increasing or decreasing it).
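A minimal sketch of that idea, assuming a hypothetical as_objective helper (not an existing torchmetrics API) that turns a computed metric value into a quantity a sweeper can always maximize:

import torch
from torchmetrics import MeanSquaredError

def as_objective(metric, higher_is_better: bool) -> torch.Tensor:
    # hypothetical helper: negate metrics that should be minimized,
    # so the sweeper can always maximize the returned value
    val = metric.compute()
    return val if higher_is_better else -val

mse = MeanSquaredError()
mse.update(torch.randn(100), torch.randn(100))
print(as_objective(mse, higher_is_better=False))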
Another example of a package for hyperparameter optimization that also takes an objective: http://hyperopt.github.io/hyperopt/
I'd like to see this implemented as well. We're using PL + Optuna (+ Hydra's plugin_sweeper_optuna) and running into the same problem, especially when a model's metric is configurable.
I think the approach with a property direction() -> 'min'/'max' is simple and would suffice.
While the solutions with wrappers work, I think it'd be good if PL somehow standardized this, so that other HP optimization libraries can integrate with it.
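A rough sketch of that kind of integration, assuming a hypothetical direction attribute on the metric ("minimize"/"maximize"); the Optuna calls themselves are real API, and the training step is elided:

import optuna
import torch
from torchmetrics import MeanSquaredError

metric = MeanSquaredError()
# hypothetical standardized attribute; the fallback value is only for this sketch
direction = getattr(metric, "direction", "minimize")

def objective(trial: optuna.Trial) -> float:
    lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
    # ... train a model with `lr` and feed its predictions/targets to `metric` ...
    metric.reset()
    metric.update(torch.randn(100), torch.randn(100))  # placeholder data for the sketch
    return metric.compute().item()

study = optuna.create_study(direction=direction)
study.optimize(objective, n_trials=20)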
Okay, then let's settle on adding a property to each metric.
- What should it be named? direction -> 'min'/'max', minimize -> True/False, or higher_is_better -> True/False? (See the sketch after this list.)
- It should not be implemented for all metrics. ConfusionMatrix comes to mind, where it does not make sense to talk about one value being better than another.
- How do we deal with metrics with multiple outputs and metrics with multidimensional output?
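For the naming question, a minimal sketch of what such a property could look like, using higher_is_better with None for metrics where a direction makes no sense; this is purely illustrative and not an existing torchmetrics attribute:

from typing import Optional

# illustrative stand-in classes, not real torchmetrics metrics
class MyAccuracy:
    higher_is_better: Optional[bool] = True    # accuracy should be maximized

class MyMeanSquaredError:
    higher_is_better: Optional[bool] = False   # error should be minimized

class MyConfusionMatrix:
    higher_is_better: Optional[bool] = None    # no meaningful ordering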
ConfusionMatrix comes to mind where it [min/max] does not make sense

Should we then add direction -> min/max/None?

How do we deal with metrics with multiple outputs and metrics with multidimensional output?

I.e. Optuna lets you define a tuple of directions:
direction:
- minimize
- maximize
I'd say we don't care about that for the first iteration and just leave these as None; we cannot decide on a Pareto-optimal front anyway.
... and you probably meant the multi-dim metric output, not multi-dim optimization, right?
For the multidim output, we need a form of reduction. Can we say that for the first draft, this feature only works for metrics of the form Loss(y_hat: Tensor, y: Tensor) -> float?
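A small sketch of that first-draft restriction, assuming a hypothetical guard (not part of torchmetrics) that only accepts metrics whose compute() returns a single scalar:

import torch

def objective_value(metric) -> float:
    # hypothetical guard: reject multi-output / multidimensional results for now
    val = metric.compute()
    if not (isinstance(val, torch.Tensor) and val.ndim == 0):
        raise ValueError("first draft: only metrics returning a single scalar value are supported")
    return val.item()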
For multi-output metrics we need the ability to extract the value that is actually being optimized over. E.g. some metrics can return a value and the corresponding threshold (e.g. recall@precision=90%) and we only want to optimize over the actual value.
@maximsch2 @breznak @SkafteNicki how is it going here? do we have a resolution on what to do?
I think we got stuck on the more advanced cases (e.g. metrics that return multiple values, as above). While I see it's important to design it well so it works for all use cases in the future, I think we should find an MVP that we can easily implement, otherwise this will likely get stuck.
In practice, what we're running into is that this would ideally be a coordinated "API" for pl.metrics and torchmetrics.

E.g. some metrics can return a value and the corresponding threshold (e.g. recall@precision=90%) and we only want to optimize over the actual value.

Could you elaborate on this example, please, @maximsch2? From what I understand, the metric returns multiple values for several thresholds. But wouldn't the direction still be the same for all of them? (recall -> max?)
In practice, what we're running into is that this would ideally be a coordinated "API" for pl.metrics and torchmetrics.

@breznak since pl.metrics will be deprecated in v1.3 of Lightning and completely removed in v1.5, we only need to think about the torchmetrics API.

E.g. some metrics can return a value and the corresponding threshold (e.g. recall@precision=90%) and we only want to optimize over the actual value.
Could you elaborate on this example, please, @maximsch2? From what I understand, the metric returns multiple values for several thresholds. But wouldn't the direction still be the same for all of them? (recall -> max?)

I think what @maximsch2 is referring to is that metrics such as PrecisionRecallCurve have 3 outputs:
precision, recall, thresholds = pr_curve(pred, target)
where I basically want to optimize the precision/recall but not the threshold values.
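To make the multi-output case concrete, here is a rough sketch of reducing such a curve to a single optimization target (recall at precision >= 0.9, as in the recall@precision example above); BinaryPrecisionRecallCurve is a real torchmetrics class, but the selection logic is only illustrative:

import torch
from torchmetrics.classification import BinaryPrecisionRecallCurve

preds = torch.rand(100)
target = torch.randint(0, 2, (100,))

pr_curve = BinaryPrecisionRecallCurve()
precision, recall, thresholds = pr_curve(preds, target)

# keep only operating points with precision >= 0.9 and take the best recall there
mask = precision >= 0.9
objective = recall[mask].max() if mask.any() else torch.tensor(0.0)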
we only need to think about the torchmetrics API.

Good to know, thanks! Then it should be easier.

precision, recall, thresholds = pr_curve(pred, target)
where I basically want to optimize the precision/recall but not the threshold values.

How about adding a "tell us which one value is the optimization criterion for you" option to the metric, then? Something like precision, recall, thresholds = pr_curve(pred, target, optimize='recall'). Then we have one number that represents the "important" result of such a metric.
I'm actually thinking that maybe let's defer the multi-output metrics to later, as long as we can support those in CompositionalMetric. E.g. for single-output metrics we'll provide higher_is_better, but for multi-output metrics we'll skip it and rely on people doing something like CompositionalMetric(RecallAndThresholdMetric()[0], None, higher_is_better=True), which will implement the needed functions and return the single value?
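A bare-bones sketch of a wrapper along those lines (entirely hypothetical, not an existing torchmetrics class): pick one output of a multi-output metric and attach a direction to it, so downstream tooling only ever sees a single value plus higher_is_better:

# hypothetical wrapper, shown only to illustrate the idea above
class SelectOutput:
    def __init__(self, metric, index: int, higher_is_better: bool):
        self.metric = metric
        self.index = index
        self.higher_is_better = higher_is_better

    def update(self, *args, **kwargs):
        self.metric.update(*args, **kwargs)

    def compute(self):
        # forward only the selected output, e.g. recall from (recall, threshold)
        return self.metric.compute()[self.index]

    def reset(self):
        self.metric.reset()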
I'm for starting small, but doing it rather soon.
Btw, it'd be nice to get people from the Optuna/Ray/Ax/etc. PL sweepers here, as they might have valuable feedback.
This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.