Comments (1)
I found a solution. The key part is the `LoadData()` method.
/// <summary>
/// Trains, evaluates, persists and applies a multi-class text classifier that maps a
/// free-text field value to a <see cref="FooFieldType"/>.
/// </summary>
public class FooFieldClassifier
{
    private readonly FooDbContext _fooDbContext;
    private readonly IConfiguration _configuration;
    private readonly MLContext _mlContext;
    private readonly Lazy<ITransformer> _trainedModelLazy;
    private ITransformer? _trainedModel;
    private TrainTestData _trainTestData;

    /// <summary>
    /// True once the model has been loaded from <see cref="TrainedModelPath"/>.
    /// A model produced in memory by <see cref="Train"/> does not flip this flag.
    /// </summary>
    public bool IsTrainedModelLoaded => _trainedModelLazy.IsValueCreated;

    /// <summary>
    /// Path of the serialized model, read from the "TrainedModelPath" configuration key;
    /// falls back to "Foo.model.zip" when the key is absent.
    /// </summary>
    public string TrainedModelPath => _configuration["TrainedModelPath"] ?? "Foo.model.zip";

    /// <summary>
    /// Creates the classifier. The model file is not read here; loading is deferred
    /// until the model is first needed.
    /// </summary>
    /// <param name="mlContext">ML.NET context used for every catalog operation.</param>
    /// <param name="FooDbContext">Database context that supplies the training rows.</param>
    /// <param name="configuration">Configuration that may override the model path.</param>
    public FooFieldClassifier(MLContext mlContext, FooDbContext FooDbContext, IConfiguration configuration)
    {
        _mlContext = mlContext;
        _fooDbContext = FooDbContext;
        _configuration = configuration;
        _trainedModelLazy = new Lazy<ITransformer>(() => _mlContext.Model.Load(TrainedModelPath, out _));
    }

    /// <summary>
    /// Scores the held-out test set with the current model and writes the
    /// multi-class metrics to the debug output.
    /// </summary>
    /// <exception cref="InvalidOperationException">
    /// No model or no test set is available, i.e. <see cref="Train"/> has not run.
    /// </exception>
    public void Evaluate()
    {
        // The test split only exists after Train(); fail with a clear message instead of
        // handing null to the evaluator.
        if (_trainedModel is null || _trainTestData.TestSet is null)
        {
            throw new InvalidOperationException("Evaluate() requires a trained model and a test set; call Train() first.");
        }

        Debug.WriteLine($"=============== Evaluating to get model's accuracy metrics - Starting time: {DateTime.Now} ===============");
        var scoredTestSet = _trainedModel.Transform(_trainTestData.TestSet);
        var testMetrics = _mlContext.MulticlassClassification.Evaluate(scoredTestSet);
        Debug.WriteLine($"=============== Evaluating to get model's accuracy metrics - Ending time: {DateTime.Now} ===============");
        Debug.WriteLine("*************************************************************************************************************");
        Debug.WriteLine("* Metrics for Multi-class Classification model - Test Data ");
        Debug.WriteLine("*------------------------------------------------------------------------------------------------------------");
        Debug.WriteLine($"* MicroAccuracy: {testMetrics.MicroAccuracy:0.###}");
        Debug.WriteLine($"* MacroAccuracy: {testMetrics.MacroAccuracy:0.###}");
        // "0.###" (not "#.###") so that a value of 0 is printed as "0" rather than as an empty string.
        Debug.WriteLine($"* LogLoss: {testMetrics.LogLoss:0.###}");
        Debug.WriteLine($"* LogLossReduction: {testMetrics.LogLossReduction:0.###}");
        Debug.WriteLine("*************************************************************************************************************");
    }

    /// <summary>
    /// Ensures a model is available, loading it from <see cref="TrainedModelPath"/>
    /// unless one has already been trained or loaded.
    /// </summary>
    public void LoadModel()
    {
        GetOrLoadModel();
    }

    /// <summary>Runs the model on a single field value and returns the raw prediction.</summary>
    /// <param name="field">Text to classify.</param>
    /// <returns>The prediction produced by the model for <paramref name="field"/>.</returns>
    public FooFieldTypePrediction Predict(string field)
    {
        var model = GetOrLoadModel();
        var example = new FooModelInput(field);

        // PredictionEngine is IDisposable; release it once this single prediction is done.
        using var predEngine = _mlContext.Model.CreatePredictionEngine<FooModelInput, FooFieldTypePrediction>(model);
        var prediction = predEngine.Predict(example);
        Debug.WriteLine($"=============== Single Prediction - Result: {field}: {prediction.FooFieldType} ===============");
        return prediction;
    }

    /// <summary>
    /// Classifies a field value, trying cheap rule-based checks first and falling back
    /// to the trained model only when none of them matches.
    /// </summary>
    /// <param name="field">Raw field text; it is cleaned with <c>CleanText()</c> before any check.</param>
    /// <returns>The field type decided by the rules or, failing that, by the model.</returns>
    public FooFieldType PredictFooFieldType(string field)
    {
        field = field.CleanText();
        if (String.IsNullOrWhiteSpace(field))
        {
            return FooFieldType.SmartMode;
        }

        if (FooCertificateStatusFields.Descriptions.Value.Contains(field))
        {
            return FooFieldType.Status;
        }

        if (DateOnly.TryParse(field, out _))
        {
            return FooFieldType.CertDateStart;
        }

        if (FooTestingCenterFields.Descriptions.Value.Contains(field))
        {
            return FooFieldType.TestingCenter;
        }

        var prediction = Predict(field);
        return prediction.FooFieldType;
    }

    /// <summary>Serializes a model and its input schema to <see cref="TrainedModelPath"/>.</summary>
    /// <param name="model">Model to persist.</param>
    /// <param name="trainingDataViewSchema">Schema of the data the model was trained on.</param>
    public void SaveModelAsFile(ITransformer model, DataViewSchema trainingDataViewSchema)
    {
        _mlContext.Model.Save(model, trainingDataViewSchema, TrainedModelPath);
        Debug.WriteLine($"The model is saved to {TrainedModelPath}");
    }

    /// <summary>
    /// Loads the training rows, fits the pipeline on the training split and keeps the
    /// resulting model in memory for <see cref="Evaluate"/> and <see cref="Predict"/>.
    /// </summary>
    public void Train()
    {
        _trainTestData = LoadData();
        var pipeline = ProcessData();

        // Keep the fitted model; previously the result was discarded, so Evaluate()
        // had nothing to score with and Predict() fell back to the model file on disk.
        _trainedModel = BuildAndTrainModel(_trainTestData.TrainSet, pipeline);
    }

    /// <summary>Returns the in-memory model, loading it from disk on first use.</summary>
    private ITransformer GetOrLoadModel()
    {
        return _trainedModel ??= _trainedModelLazy.Value;
    }

    /// <summary>
    /// Appends the SDCA maximum-entropy trainer and the key-to-value mapping of the
    /// predicted label to <paramref name="pipeline"/>, then fits it.
    /// </summary>
    /// <param name="splitTrainSet">Training split to fit on.</param>
    /// <param name="pipeline">Featurization pipeline producing "Label" and "Feature".</param>
    /// <returns>The fitted transformer chain.</returns>
    private TransformerChain<KeyToValueMappingTransformer> BuildAndTrainModel(IDataView splitTrainSet, IEstimator<ITransformer> pipeline)
    {
        var trainingPipeline = pipeline
            .Append(_mlContext.MulticlassClassification.Trainers.SdcaMaximumEntropy("Label", "Feature"))
            .Append(_mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel"));
        var trainedModel = trainingPipeline.Fit(splitTrainSet);
        return trainedModel;
    }

    /// <summary>
    /// Builds one labelled example per category, enterprise, product name, model,
    /// certificate number and report number found in <c>FooSet</c>, and splits the
    /// result 80/20 into training and test sets.
    /// </summary>
    /// <returns>The train/test split of the generated examples.</returns>
    private TrainTestData LoadData()
    {
        // Materialize once: the six projections below would otherwise each re-enumerate
        // the deferred FooSet sequence.
        var foos = _fooDbContext.FooSet.AsEnumerable().ToList();

        var cats = foos.Select(x => new FooModelInput(x.ProductCategory, FooFieldType.ProductCategory));
        var ents = foos.Select(x => new FooModelInput(x.Enterprise, FooFieldType.EnterpriseName));
        var names = foos.Select(x => new FooModelInput(x.ProductName, FooFieldType.ProductName));
        var models = foos.SelectMany(x => x.Models).Select(x => new FooModelInput(x, FooFieldType.Model));
        var certs = foos.Select(x => new FooModelInput(x.CertificateNumber, FooFieldType.CertificateNo));
        var rpts = foos.SelectMany(x => x.ReportNumbers).Select(x => new FooModelInput(x, FooFieldType.ReportNo));

        var modelInputs = new[] { cats, ents, names, models, certs, rpts }.SelectMany(x => x);
        var dataView = _mlContext.Data.LoadFromEnumerable(modelInputs);
        return _mlContext.Data.TrainTestSplit(dataView, testFraction: 0.2);
    }

    /// <summary>
    /// Builds the featurization pipeline: label value-to-key mapping, text featurization
    /// of <c>Field</c> into "Feature", and a cache checkpoint.
    /// </summary>
    /// <returns>The estimator chain to which a trainer can be appended.</returns>
    private EstimatorChain<ITransformer> ProcessData()
    {
        // FooModelInput.FooFieldType carries [ColumnName("Label")], so the data view exposes
        // that property as "Label"; a column called "FooFieldType" does not exist.
        // The key mapping therefore reads and overwrites "Label".
        var pipeline = _mlContext.Transforms.Conversion
            .MapValueToKey(inputColumnName: "Label", outputColumnName: "Label")
            .Append(_mlContext.Transforms.Text.FeaturizeText(inputColumnName: nameof(FooModelInput.Field), outputColumnName: "Feature"))
            .AppendCacheCheckpoint(_mlContext);
        return pipeline;
    }
}
/// <summary>
/// One example for the field classifier: the text of a field and, for training rows,
/// the numeric value of its <c>FooFieldType</c>.
/// </summary>
public class FooModelInput
{
    /// <summary>
    /// Numeric value of the field type; mapped to the "Label" column by the attribute.
    /// Left at 0 when the instance is built without a type.
    /// </summary>
    [ColumnName("Label")]
    public int FooFieldType { get; set; }

    /// <summary>Text of the field.</summary>
    public string Field { get; set; }

    /// <summary>Creates a labelled example.</summary>
    /// <param name="field">Text of the field.</param>
    /// <param name="FooFieldType">Type the text belongs to; stored as its integer value.</param>
    public FooModelInput(string field, FooFieldType FooFieldType)
    {
        Field = field;
        // The parameter shadows the property of the same name, so the property must be
        // qualified with "this."; the unqualified form assigned to the parameter instead.
        this.FooFieldType = (int)FooFieldType;
    }

    /// <summary>Creates an unlabelled example, e.g. for prediction.</summary>
    /// <param name="field">Text of the field.</param>
    public FooModelInput(string field)
    {
        Field = field;
    }
}
/// <summary>
/// Output row of the classifier: the predicted label as an integer, plus a typed view of it.
/// </summary>
public class FooFieldTypePrediction
{
    /// <summary>The predicted label converted to the <c>FooFieldType</c> enum.</summary>
    public FooFieldType FooFieldType
    {
        get
        {
            return (FooFieldType)Prediction;
        }
    }

    /// <summary>Integer label bound to the "PredictedLabel" column.</summary>
    [ColumnName("PredictedLabel")]
    public int Prediction { get; set; }
}
from machinelearning-samples.
Related Issues (20)
- [DetectSpikeBySsa] System.AccessViolationException: Attempted to read or write protected memory. This is often an indication that other memory is corrupt.
- Faster processing
- ML.NET Sample similar to Unity's ML Agents? HOT 1
- AutoML - different results for the same input data HOT 2
- AutoML - how to read values of L1Regularization and L2Regularization and selected algorithm name?
- CrossValidate vs Fit method
- AutoML vs Cross validation
- how to inference model onnx with dynamic input sizes in ml.net ?
- Sample AnomalyDetection_PhoneCalls Project does not trigger anomalies with v3.0.1 HOT 1
- Models deployed in Windows 2019 server gets stuck in model.load
- all samples are crashing
- 'MatrixFactorizationTrainer.Options' does not contain a definition for 'K' HOT 1
- Update samples to use Swagger / Swagger UI only in development environment.
- Splitter/consolidator worker encountered exception while consuming source data HOT 1
- Unable to find entry point named "TF_StringEncodedSize" in DLL "tensorFlow" HOT 1
- outdated .net core 2.0.1 references HOT 1
- Accuracy of bike rental sample
- MachineLearning-samples - PredictionEnginePool
- System.OverflowException: 'Value was either too large or too small for an Int32.' GitHubLabeler
Recommend Projects
-
React
A declarative, efficient, and flexible JavaScript library for building user interfaces.
-
Vue.js
🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.
-
Typescript
TypeScript is a superset of JavaScript that compiles to clean JavaScript output.
-
TensorFlow
An Open Source Machine Learning Framework for Everyone
-
Django
The Web framework for perfectionists with deadlines.
-
Laravel
A PHP framework for web artisans
-
D3
Bring data to life with SVG, Canvas and HTML. 📊📈🎉
-
Recommend Topics
-
javascript
JavaScript (JS) is a lightweight interpreted programming language with first-class functions.
-
web
Something interesting about the web. A new door to the world.
-
server
A server is a program made to process requests and deliver data to clients.
-
Machine learning
Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.
-
Visualization
Something interesting about visualization; using data as art.
-
Game
Something interesting about games; make everyone happy.
Recommend Org
-
Facebook
We are working to build community through open source technology. NB: members must have two-factor auth.
-
Microsoft
Open source projects and samples from Microsoft.
-
Google
Google ❤️ Open Source for everyone.
-
Alibaba
Alibaba Open Source for everyone
-
D3
Data-Driven Documents codes.
-
Tencent
China tencent open source team.
from machinelearning-samples.