
Comments (15)

SkafteNicki commented on May 20, 2024

I think it is a great addition.
@maximsch2 is there a specific framework that you had in mind where this would give better integration?


maximsch2 commented on May 20, 2024

Yeah, maybe something like this: https://ax.dev/tutorials/tune_cnn.html

But I do imagine that any sort of sweeping requires us to be able to a) select the target metric and b) compare two runs to see if the metric has improved.


SkafteNicki commented on May 20, 2024

@maximsch2 what do you think about something like this:

from copy import deepcopy

import torch
from torchmetrics import MeanSquaredError

# registry mapping metric class -> (minimize, compare fn, initial value, output index)
_REGISTER = {}

def register(metric, minimize, index=None):
    if minimize:
        compare_fn = torch.less
        init_val = torch.tensor(float("inf"))
    else:
        compare_fn = torch.greater
        init_val = torch.tensor(float("-inf"))
    _REGISTER[metric] = (minimize, compare_fn, init_val, index)

register(MeanSquaredError, minimize=True)


class MetricCompare:
    def __init__(self, metric):
        self.base_metric = metric
        minimize, compare_fn, init_val, index = _REGISTER[type(metric)]
        self._minimize = minimize
        self._compare_fn = compare_fn
        self._index = index
        self._init_val = init_val
        self._new_val = deepcopy(init_val)
        self._old_val = deepcopy(init_val)
        
    def update(self, *args, **kwargs):
        self.base_metric.update(*args, **kwargs)

    def compute(self):
        # stash the previous value so has_improved can compare against it
        self._old_val = self._new_val
        val = self.base_metric.compute()
        self._new_val = val.detach()
        return val
    
    def reset(self):
        self.base_metric.reset()
        self._new_val = deepcopy(self._init_val)
        self._old_val = deepcopy(self._init_val)

    @property
    def has_improved(self):
        # compare the latest compute() result against the previous one
        if self._index is None:
            return self._compare_fn(self._new_val, self._old_val)
        # for multi-output metrics, compare only the registered output
        return self._compare_fn(self._new_val[self._index], self._old_val[self._index])

    @property
    def minimize(self):
        return self._minimize
    
    @property
    def maximize(self):
        return not self.minimize

metric = MetricCompare(MeanSquaredError())
metric.update(torch.randn(100,), torch.randn(100,))
val = metric.compute()
print(metric.has_improved)

This is basically a wrapper for metrics that adds properties telling whether the metric should be minimized or maximized and, after compute is called, whether it has improved.


maximsch2 commented on May 20, 2024

Usually sweeps will be run in a distributed fashion (e.g. schedule runs with different hyperparams separately, compute metric values, pick the one with the best metric), so has_improved might not be as useful there.

Thinking about it a bit more, just providing a way to convert a metric to an optimization value might be enough (with a semantics of whether we are increasing or decreasing it).

Another example of a package for hyperparameter optimization that also takes an objective: http://hyperopt.github.io/hyperopt/
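
A minimal sketch of that conversion, assuming only that we know the metric's direction (to_objective and higher_is_better are illustrative names, not torchmetrics API):

def to_objective(metric_value: float, higher_is_better: bool) -> float:
    """Fold a metric value and its direction into a single scalar
    that a minimizing optimizer (e.g. hyperopt's fmin) can consume."""
    return -metric_value if higher_is_better else metric_value

assert to_objective(0.9, higher_is_better=True) == -0.9   # accuracy-like
assert to_objective(0.1, higher_is_better=False) == 0.1   # loss-like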


breznak commented on May 20, 2024

I'd like to see this implemented as well. We're using PL + Optuna (+ Hydra's plugin_sweeper_optuna) and running into the same problem, especially when a model's metric is configurable.

I think the approach with a property direction() -> 'min'/'max' is simple and would suffice.

While the wrapper solutions work, I think it'd be good if PL somehow standardized this, so that other HP optimization libraries can integrate with it.


SkafteNicki commented on May 20, 2024

Okay, then let's settle on adding a property to each metric.

  1. What should it be named?
     direction -> 'min'/'max',
     minimize -> True/False,
     higher_is_better -> True/False
  2. It should not be implemented for all metrics. ConfusionMatrix comes to mind, where it does not make sense to talk about when one value is better than another (see the sketch below for how a None option could cover this).
  3. How do we deal with metrics with multiple outputs and metrics with multidim output?
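
A rough sketch of what the higher_is_better option could look like, with None covering point 2 (the class names here are stand-ins, not the real torchmetrics classes):

from typing import Optional

# Hypothetical convention: each metric class carries an optional flag;
# None covers metrics like ConfusionMatrix where no ordering exists.
class Accuracy:
    higher_is_better: Optional[bool] = True    # maximize

class MeanSquaredError:
    higher_is_better: Optional[bool] = False   # minimize

class ConfusionMatrix:
    higher_is_better: Optional[bool] = None    # comparison undefined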


breznak commented on May 20, 2024

ConfusionMatrix comes to mind where it [min/max] does not make sense

add None, i.e. min/max/None?

How do we deal with metrics with multiple outputs and metrics with multidim output?

I.e. Optuna lets you define a tuple:

direction: 
 - minimize
 - maximize

I'd say we don't care for the first iteration and just leave these as None. We cannot decide on a Pareto-optimal front anyway.

... and you probably meant a multi-dim metric's output, not multidim optimization, right?
For the multidim output, we need a form of reduction.

Can we say that for the first draft, this feature works only for metrics of the form Loss(y_hat: Tensor, y: Tensor) -> float?
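
For reference, this is how such a flag maps onto Optuna's (real) study API, which takes the direction as a string:

import optuna

higher_is_better = True  # e.g. read from the metric
study = optuna.create_study(
    direction="maximize" if higher_is_better else "minimize"
)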


maximsch2 commented on May 20, 2024

For multi-output metrics we need the ability to extract the value that is actually being optimized over. E.g. some metrics can return a value and the corresponding threshold (e.g. recall@precision=90%), and we only want to optimize over the actual value.


Borda commented on May 20, 2024

@maximsch2 @breznak @SkafteNicki how is it going here? Do we have a resolution on what to do?


breznak commented on May 20, 2024

I think we got stuck on the more advanced cases (e.g. metrics that return multiple values, as above). While I see it's important to design this well so it works for all use cases in the future, I think we should find an MVP that we can easily implement; otherwise this will likely stall.

In practice, what we're running into is that this would ideally be a coordinated "API" for pl.metrics and torchmetrics.

E.g. some metrics can return a value and the corresponding threshold (e.g. recall@precision=90%), and we only want to optimize over the actual value.

could you elaborate on this example, please, @maximsch2? From what I understand, the metric returns multiple values for several thresholds. But wouldn't the direction still be the same for all of them (recall -> max)?


SkafteNicki commented on May 20, 2024

In practice, what we're running into is that this would ideally be a coordinated "API" for pl.metrics and torchmetrics.

@breznak since pl.metrics will be deprecated in v1.3 of lightning and completely removed from v1.5, we only need to think about the torchmetrics API.

E.g. some metrics can return a value and the corresponding threshold (e.g. recall@precision=90%), and we only want to optimize over the actual value.

could you elaborate on this example, please, @maximsch2? From what I understand, the metric returns multiple values for several thresholds. But wouldn't the direction still be the same for all of them (recall -> max)?

I think what @maximsch2 is referring to is that metrics such as PrecisionRecallCurve have 3 outputs:

precision, recall, thresholds = pr_curve(pred, target)

where I basically want to optimize the precision/recall but not the threshold values.
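
With the register()/index idea sketched earlier in this thread, this case could be covered by pointing the comparison at one element of the output tuple, e.g.:

from torchmetrics import PrecisionRecallCurve

# Reuses the hypothetical register() from the earlier sketch; recall is
# output index 1 of PrecisionRecallCurve and should be maximized. Note
# that a reduction (e.g. AUC) would still be needed before comparing,
# since the recall output is a whole curve, not a scalar.
register(PrecisionRecallCurve, minimize=False, index=1)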


breznak commented on May 20, 2024

we only need to think about the torchmetrics API.

good to know, thanks! then it should be easier.

precision, recall, thresholds = pr_curve(pred, target)
where I basically want to optimize the precision/recall but not the threshold values.

how about adding a "tell us which single optimization criterion matters to you" option to the metric, then?
Like precision, recall, thresholds = pr_curve(pred, target, optimize='recall').
Then we have one number that represents the "important" result of such a metric.


maximsch2 commented on May 20, 2024

I'm actually thinking that maybe we should defer multi-output metrics until later, as long as we can support them in CompositionalMetric. E.g. for single-output metrics we'll provide higher_is_better, but for multi-output metrics we'll skip it and rely on people doing something like CompositionalMetric(RecallAndThresholdMetric()[0], None, higher_is_better=True), which will implement the needed functions and return the single value.
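
A rough sketch of how that could look, assuming indexing a Metric yields a CompositionalMetric that applies the selection to compute()'s output (RecallAndThreshold is a hypothetical stand-in, and higher_is_better is the proposed addition, not existing API):

import torch
from torchmetrics import Metric

# Hypothetical metric whose compute() returns (value, threshold);
# only the value should be optimized over.
class RecallAndThreshold(Metric):
    def __init__(self):
        super().__init__()
        self.add_state("value", default=torch.tensor(0.0), dist_reduce_fx="mean")

    def update(self, preds, target):
        ...  # update self.value from preds/target

    def compute(self):
        return self.value, torch.tensor(0.5)  # (recall, threshold)

# Indexing keeps only the optimized value:
recall_only = RecallAndThreshold()[0]
# proposed addition (not existing API): higher_is_better=True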


breznak commented on May 20, 2024

I'm for starting small, but doing it rather soon.
Btw, it'd be nice to get the people behind the Optuna/Ray/Ax/etc. PL sweepers in here, as they might have valuable feedback.


stale commented on May 20, 2024

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.
