Comments (5)
Current Plan (looking for feedback)
- Create an `EvalDataset` (better name suggestion?) data structure to store at least a dataloader (+ relevant hparams) and a list of metrics for each evaluation dataset.
- Trainer is passed an `EvalDataset` or a `List[EvalDataset]`.

Questions + Fields that need changing
- `compute_training_metrics` (trainer): since metrics would become `EvalDataset`-specific, we would need to store which metrics are going to be used for these operations.
- Do evaluation parameters like `eval_batch_size` change for different datasets? Should this also be a part of `EvalDataset`?
- For validating every N batches, would it be helpful to store a flag in `EvalDataset` that determines whether a particular dataset is used? (default=True)
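For concreteness, a minimal sketch of the structure being proposed. All field names and defaults here are illustrative guesses, not composer's actual API:

```python
from dataclasses import dataclass, field
from typing import Any, List


@dataclass
class EvalDataset:
    """Hypothetical container bundling one evaluation dataset with its settings."""

    dataloader: Any                  # e.g. a torch DataLoader (+ relevant hparams)
    metric_names: List[str] = field(default_factory=list)  # metrics for this dataset
    eval_batch_size: int = 256       # per-dataset override, if these can differ
    active: bool = True              # skip this dataset during periodic validation if False


# The trainer would then accept one of these, or a list:
#   Trainer(model=..., eval_dataset=[EvalDataset(dl1, ["Accuracy"]),
#                                    EvalDataset(dl2, ["CrossEntropy"])])
```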
from composer.
I agree with the need for an `EvalDataset` (suggestion: rename to `Evaluator`) that stores more information than the current `val_dataset` hparams object does. In general, I like the idea of things like `validate_every_n_batches` being moved into this object, as long as we can continue to make it reasonable to use our trainer without getting hparams involved (something I'm about to take a stab at). The main question I have is how the `metrics` function of the model would determine which `Evaluator` is invoking it (and which metrics to calculate); I think it would be OK for `metrics` to take a second optional parameter with the ID of the invoking `Evaluator`.
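A rough sketch of that optional-parameter idea, using a placeholder model and string stand-ins for real metric objects (all names here are hypothetical):

```python
from typing import Dict, Optional


class ExampleModel:
    """Placeholder model whose metrics() can be keyed by the invoking Evaluator's ID."""

    def metrics(self, evaluator_id: Optional[str] = None) -> Dict[str, str]:
        all_metrics = {"Accuracy": "<Accuracy metric>", "CrossEntropy": "<CE metric>"}
        # A per-evaluator mapping lets each Evaluator receive only the metrics it needs.
        per_evaluator = {"small_val_set": ["Accuracy"]}
        if evaluator_id in per_evaluator:
            return {name: all_metrics[name] for name in per_evaluator[evaluator_id]}
        return all_metrics  # no ID (or an unknown one): fall back to every metric
```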
Also hoping to see feedback from @ajaysaini725, who IIRC contributed significantly to the metrics code before launch.
from composer.
Could we perhaps move the metrics and validate functions inside the evaluator and remove them from the model? Or we could use model metrics as defaults if none are specified by an `Evaluator`, so a user can optionally just specify the metrics once? @jbloxham @ajaysaini725
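The default-fallback idea could look something like this (purely illustrative; `get_metrics` is a made-up helper, not an existing composer method):

```python
class Evaluator:
    """Illustrative Evaluator that defaults to the model's metrics when given none."""

    def __init__(self, dataloader, metrics=None):
        self.dataloader = dataloader
        self._metrics = metrics  # None means "use whatever the model declares"

    def get_metrics(self, model):
        # The user specified metrics once, on the model; reuse them unless overridden here.
        return self._metrics if self._metrics is not None else model.metrics()
```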
from composer.
A few things to consider:
- The output of `model.validate(eval_data_batch) -> Tuple[Any, Any]` gets passed straight into each metric returned by `model.metrics()`, so the `validate()` function needs to be compatible with all `EvalDataset`s. Also, we should definitely keep `validate()` as a part of the model definition because it relies on the forward-pass logic of the model.
- Ideally we don't want to specify anything data-related in the model class, because the model class definition should be an isolated unit that works across many different training runs. It's better to have each `EvalDataset` specify the metrics it needs. The model class definitions should still have a `metrics()` function that returns all metrics which can be computed for the model, and then each `EvalDataset` can specify exactly which metrics are relevant for that dataset (this can be done in `DatasetHparams` and should also be stored in `DataloaderSpec` so that it works with the `Trainer` `__init__`). Having the relevant metrics for a model be returned by the `model.metrics()` function prevents an `EvalDataset` from specifying a metric that doesn't work for a particular model (i.e. `model.metrics()` specifies all metrics for a model and then `EvalDataset` specifies the subset of metrics to use for that dataset). If `EvalDataset` doesn't specify anything, then the default is to use all metrics.
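The all-metrics-then-subset rule in the second bullet could be captured by a small helper like this (a hypothetical function, not part of composer):

```python
from typing import Dict, List


def select_metrics(model_metrics: Dict[str, object], requested: List[str]) -> Dict[str, object]:
    """Pick the subset of a model's metrics that an EvalDataset asks for.

    Raises if the dataset requests a metric the model cannot compute, which is
    the safety property described above; an empty request means "use all metrics".
    """
    if not requested:
        return dict(model_metrics)
    unknown = set(requested) - set(model_metrics)
    if unknown:
        raise ValueError(f"Model cannot compute: {sorted(unknown)}")
    return {name: model_metrics[name] for name in requested}
```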
Also, another thought: what if we move `train_dataset` and `eval_dataset` to be parameters of `Trainer.fit()` rather than parameters of `Trainer` instantiation?
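Shape-wise, that alternative would look roughly like the stub below (illustrative only; composer's real `Trainer` signature differs):

```python
class Trainer:
    """Stub showing datasets passed to fit() instead of bound at construction time."""

    def __init__(self, model):
        self.model = model  # no datasets held by the Trainer itself

    def fit(self, train_dataset, eval_dataset=None):
        # The training loop would consume the datasets here rather than at __init__,
        # so one Trainer could be reused across different dataset pairs.
        return {"train": train_dataset, "eval": eval_dataset}
```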
from composer.
Makes sense - I like the idea of making the datasets parameters to the fit function.
from composer.