Comments (5)
@iksooman you just have to run rm -rf out before running the code for a new model.
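For example, assuming the default work_dir of 'out' used in the quick-start script, clearing the cached files before re-running looks like this (a minimal sketch; the cached file names are illustrative):

```shell
# Remove the quantizer's cached rewritten-model files ('out' is the
# work_dir passed to QATQuantizer in the quick-start example) so they
# are regenerated for the new model on the next run.
rm -rf out
```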
from tinyneuralnetwork.
Is it OK to modify line 2879 of TinyNeuralNetwork/tinynn/graph/quantization/quantizer.py from

node = graph.nodes_map[name]

to

try:
    node = graph.nodes_map[name]
except KeyError:
    continue
@iksooman No, I think the correct thing to do is to remove everything in that directory, since it usually means that the model has been changed or updated. I think the following logic would be better:
if name in graph.nodes_map:
    node = graph.nodes_map[name]
else:
    log.error(f'Node name {name} not found in configuration file, it probably means that your model has been updated. Please remove the old yaml file and try again')
    assert False
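As a standalone sketch of that guard (names like nodes_map and lookup are illustrative here, not the actual quantizer internals):

```python
import logging

log = logging.getLogger(__name__)

# Stand-in for the tracer's name -> node mapping loaded from the cached yaml file.
nodes_map = {'conv1': 'conv-node', 'relu1': 'relu-node'}

def lookup(name):
    # Fail loudly on an unknown name instead of silently skipping it,
    # so a stale configuration file is surfaced to the user.
    if name in nodes_map:
        return nodes_map[name]
    log.error(
        f'Node name {name} not found in configuration file, it probably means that '
        'your model has been updated. Please remove the old yaml file and try again'
    )
    raise AssertionError(f'stale configuration: unknown node {name}')
```

Raising (rather than except KeyError: continue) is the point of the suggestion: a silent skip would leave the model partially rewritten, while a hard failure tells the user to delete the stale yaml.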
@peterjc123 Could you please give me a slightly more detailed explanation?
I just replaced the model in the quick_start_for_expert.py file. Between the part where the model is defined and the call to quantizer.quantize(), there seems to be only pruning. Do I need to modify the settings in the pruning part?
def main_worker(args):
    print("###### TinyNeuralNetwork quick start for expert ######")

    # If you encounter any problems, please set the global log level to `DEBUG`, which may make it easier to debug.
    # set_global_log_level("DEBUG")

    model = mobilenet.Mobilenet()
    model.load_state_dict(torch.load(mobilenet.DEFAULT_STATE_DICT))

    device = get_device()
    model.to(device=device)

    if args.distillation:
        teacher = copy.deepcopy(model)

    if args.parallel:
        model = nn.DataParallel(model)

    # Provide a viable input for the model
    dummy_input = torch.rand((1, 3, 224, 224))

    context = DLContext()
    context.device = device
    context.train_loader, context.val_loader = get_dataloader(args.data_path, 224, args.batch_size, args.workers)

    print("Validation accuracy of the original model")
    validate(model, context)

    print("Start pruning the model")
    # If you need to set the sparsity of a single operator, then you may refer to the examples in `examples/pruner`.
    pruner = OneShotChannelPruner(model, dummy_input, {"sparsity": 0.75, "metrics": "l2_norm"})

    st_flops = pruner.calc_flops()
    pruner.prune()  # Get the pruned model

    print("Validation accuracy of the pruned model")
    validate(model, context)

    ed_flops = pruner.calc_flops()
    print(f"Pruning over, reduced FLOPS {100 * (st_flops - ed_flops) / st_flops:.2f}% ({st_flops} -> {ed_flops})")

    print("Start finetuning the pruned model")
    # In our experiments, using the same learning rate configuration as the one used during training from scratch
    # leads to a higher final model accuracy.
    context.max_epoch = 220
    context.criterion = nn.BCEWithLogitsLoss()
    context.optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    context.scheduler = CosineAnnealingLR(context.optimizer, T_max=context.max_epoch + 1, eta_min=0)

    if args.warmup:
        print("Use warmup")
        context.warmup_iteration = len(context.train_loader) * 10  # warm up for 10 epochs
        context.warmup_scheduler = CyclicLR(
            context.optimizer, base_lr=0, max_lr=0.1, step_size_up=context.warmup_iteration
        )

    if args.distillation:
        # Using distillation may lead to better accuracy at the price of longer training time.
        print("Use distillation")
        context.custom_args = {'distill_A': 0.3, 'distill_T': 6, 'distill_teacher': teacher}
        train(model, context, train_one_epoch_distill, validate)
    else:
        train(model, context, train_one_epoch, validate)

    print("Start preparing the model for quantization")
    # We provide a QATQuantizer class that may rewrite the graph and perform model fusion for quantization.
    # The model returned by the `quantize` function is ready for QAT training.
    quantizer = QATQuantizer(model, dummy_input, work_dir='out')
    qat_model = quantizer.quantize()
@peterjc123 problem solved. Thank you!
Related Issues (20)
- [converter] map gather(+reshape) ops with separate consecutive indices to split(unpack) ops
- tinynn.converter module not found! HOT 2
- [CI] several tests for modifier failed
- Whether to support pytorch to keras HOT 1
- TransposeConv wrong shape? HOT 15
- change input to INT8 after converting to tflite HOT 2
- [converter] implement torch's `aten::scaled_dot_product_attention` operator HOT 2
- Request: clamp would be more efficient to go to Bounded Relu than Maximum + Minimum HOT 3
- Do not support PReLU module? HOT 5
- torch.max not working HOT 2
- OneShotChannelPruner results in the miss of some operators HOT 4
- PyTorch to TFLite conversion with int8 quantization HOT 4
- Does tinynn support following int16 quantization? HOT 1
- jit.trace succeed but tinynn tracer failed HOT 1
- It became larger after converting to tflite model HOT 4
- how to do Post-training integer quantization with int16 activation HOT 4
- unnecessary float() variables cause quantization to fail. HOT 7
- aten::index nodes take multiple indices in PyTorch model but cause an error when trying to convert to TFLite HOT 1
- aten::repeat_interleave is considered an unsupported Tensor and causing an error when trying to convert to TFLite HOT 2