`tf.keras` supports this usage by default. Take the following code snippet as an example:
```python
import tensorflow as tf
import hybridbackend.tensorflow as hb


def input_dataset(args, filenames, batch_size):
  r'''Get input dataset.
  '''
  with tf.device('/cpu:0'):
    ds = hb.data.ParquetDataset(
        filenames,
        batch_size=batch_size,
        num_parallel_reads=len(filenames),
        num_parallel_parser_calls=args.num_parsers,
        drop_remainder=True)
    ds = ds.apply(hb.data.parse())
    # Split each batch into (features, labels): every column except 'ts' and
    # 'label' is a feature, and the same label feeds both model outputs.
    ds = ds.map(
        lambda batch: (
            {f: batch[f] for f in batch if f not in ('ts', 'label')},
            {'output_1': tf.reshape(batch['label'], shape=[-1, 1]),
             'output_2': tf.reshape(batch['label'], shape=[-1, 1])}))
    ds = ds.prefetch(args.num_prefetches)
  return ds
```
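To make the `(features, labels)` structure produced by the `map` call easier to inspect, here is a minimal sketch in plain Python, assuming each batch is a dict of column lists rather than tensors. `split_features_and_labels` and its `num_outputs` parameter are hypothetical names introduced for illustration; they are not part of HybridBackend or `tf.keras`.

```python
from typing import Any, Dict, Tuple


# Hypothetical helper (not a HybridBackend or tf.keras API): mirrors the lambda
# passed to ds.map() above, using plain Python lists instead of tensors so the
# key layout can be checked without TensorFlow.
def split_features_and_labels(
    batch: Dict[str, Any],
    num_outputs: int,
    label_key: str = 'label',
    excluded: Tuple[str, ...] = ('ts', 'label'),
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
  """Returns (features, labels) with labels keyed 'output_1'..'output_N'."""
  features = {f: v for f, v in batch.items() if f not in excluded}
  # Turn the flat label column into one single-element row per example,
  # analogous to tf.reshape(batch['label'], shape=[-1, 1]).
  labels = {
      f'output_{i}': [[x] for x in batch[label_key]]
      for i in range(1, num_outputs + 1)
  }
  return features, labels


batch = {'ts': [100, 101], 'user_id': [7, 8], 'label': [0, 1]}
features, labels = split_features_and_labels(batch, num_outputs=2)
print(features)  # {'user_id': [7, 8]}
print(labels)    # {'output_1': [[0], [1]], 'output_2': [[0], [1]]}
```

Generating the keys from `num_outputs` keeps the label dict in step with the number of outputs, instead of hard-coding `'output_1'` and `'output_2'`.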
Note that the labels must be a dict whose keys follow the pattern `{'output_1': label1, 'output_2': label2, ...}`. The keys must be written as `output_${i}` (as required by `tf.keras`), where `${i}` is the index of the i-th output in the functional API, starting from 1. Accordingly, the model is created via the functional API as follows:
```python
your_model = tf.keras.Model(inputs=xxx, outputs=[output_a, output_b])
```
`tf.keras` then implicitly produces the dict `{'output_1': output_a, 'output_2': output_b}` and associates it with the multiple labels `{'output_1': label1, 'output_2': label2}`.
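The key-matching rule described above can be modeled in a few lines of plain Python. This is a sketch of the convention as stated in this comment, not of `tf.keras` internals; `match_outputs_to_labels` is a hypothetical helper, and the strings stand in for real output and label tensors.

```python
from typing import Any, Dict, Sequence, Tuple


# Hypothetical helper modeling the convention: the i-th entry (1-based) of
# outputs=[...] is keyed 'output_{i}', and every key must have a matching
# entry in the label dict. Illustration only, not tf.keras's implementation.
def match_outputs_to_labels(
    outputs: Sequence[Any], labels: Dict[str, Any]
) -> Dict[str, Tuple[Any, Any]]:
  named = {f'output_{i}': out for i, out in enumerate(outputs, start=1)}
  missing = sorted(set(named) - set(labels))
  unexpected = sorted(set(labels) - set(named))
  if missing or unexpected:
    raise KeyError(
        f'label keys do not match outputs: missing={missing}, '
        f'unexpected={unexpected}')
  # Pair each named output with the label that shares its key.
  return {key: (named[key], labels[key]) for key in named}


pairs = match_outputs_to_labels(
    ['output_a', 'output_b'], {'output_1': 'label1', 'output_2': 'label2'})
print(pairs)
# {'output_1': ('output_a', 'label1'), 'output_2': ('output_b', 'label2')}
```

In this sketch, a label dict keyed `'label_1'`, `'label_2'` raises a `KeyError`, which is the kind of mismatch the `output_${i}` naming is meant to avoid.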
from hybridbackend.