Coder Social home page Coder Social logo

Comments (1)

CodingOctocat avatar CodingOctocat commented on June 12, 2024

I found a solution. The focus is on the LoadData() part.

/// <summary>
/// Trains, evaluates, saves and queries an ML.NET multiclass model that predicts which
/// <see cref="FooFieldType"/> a piece of text belongs to.
/// </summary>
public class FooFieldClassifier
{
    // Column names shared by the pipeline stages. "Label" is also the IDataView column name
    // that FooModelInput assigns to its FooFieldType property via [ColumnName("Label")].
    private const string LabelColumn = "Label";
    private const string FeatureColumn = "Feature";
    private const string PredictedLabelColumn = "PredictedLabel";

    private readonly FooDbContext _fooDbContext;

    private readonly IConfiguration _configuration;

    private readonly MLContext _mlContext;

    // Defers reading the model file until LoadModel() is first called.
    private readonly Lazy<ITransformer> _trainedModelLazy;

    // The model in use: either loaded from TrainedModelPath or produced by Train().
    private ITransformer? _trainedModel;

    // Train/test split produced by the last call to Train(); default until then.
    private TrainTestData _trainTestData;

    /// <summary>
    /// True once a model is available in memory, whether it was loaded from
    /// <see cref="TrainedModelPath"/> or trained by <see cref="Train"/>.
    /// </summary>
    public bool IsTrainedModelLoaded => _trainedModel is not null;

    /// <summary>
    /// Path of the model file, read from the "TrainedModelPath" configuration key,
    /// falling back to "Foo.model.zip" when the key is absent.
    /// </summary>
    public string TrainedModelPath => _configuration["TrainedModelPath"] ?? "Foo.model.zip";

    /// <summary>Creates a classifier; the model file is not read until it is first needed.</summary>
    /// <param name="mlContext">ML.NET context used for every ML operation.</param>
    /// <param name="FooDbContext">Database context that supplies the training records.</param>
    /// <param name="configuration">Configuration providing <see cref="TrainedModelPath"/>.</param>
    public FooFieldClassifier(MLContext mlContext, FooDbContext FooDbContext, IConfiguration configuration)
    {
        _mlContext = mlContext;
        _fooDbContext = FooDbContext;
        _configuration = configuration;

        _trainedModelLazy = new Lazy<ITransformer>(() => _mlContext.Model.Load(TrainedModelPath, out _));
    }

    /// <summary>
    /// Scores the test split with the current model and writes the multiclass metrics
    /// to the debug output.
    /// </summary>
    /// <exception cref="InvalidOperationException">
    /// No model or no test split is available, i.e. <see cref="Train"/> has not been run.
    /// </exception>
    public void Evaluate()
    {
        // Fail with an explicit message instead of handing null to the evaluator.
        if (_trainedModel is null || _trainTestData.TestSet is null)
        {
            throw new InvalidOperationException(
                "No trained model and test set are available; call Train() before Evaluate().");
        }

        Debug.WriteLine($"=============== Evaluating to get model's accuracy metrics - Starting time: {DateTime.Now} ===============");

        var testMetrics = _mlContext.MulticlassClassification.Evaluate(_trainedModel.Transform(_trainTestData.TestSet));

        Debug.WriteLine($"=============== Evaluating to get model's accuracy metrics - Ending time: {DateTime.Now} ===============");
        Debug.WriteLine($"*************************************************************************************************************");
        Debug.WriteLine($"*       Metrics for Multi-class Classification model - Test Data     ");
        Debug.WriteLine($"*------------------------------------------------------------------------------------------------------------");
        Debug.WriteLine($"*       MicroAccuracy:    {testMetrics.MicroAccuracy:0.###}");
        Debug.WriteLine($"*       MacroAccuracy:    {testMetrics.MacroAccuracy:0.###}");
        Debug.WriteLine($"*       LogLoss:          {testMetrics.LogLoss:#.###}");
        Debug.WriteLine($"*       LogLossReduction: {testMetrics.LogLossReduction:#.###}");
        Debug.WriteLine($"*************************************************************************************************************");
    }

    /// <summary>
    /// Ensures a model is in memory, loading it from <see cref="TrainedModelPath"/> unless one
    /// is already present.
    /// </summary>
    public void LoadModel()
    {
        _trainedModel ??= _trainedModelLazy.Value;
    }

    /// <summary>Runs the model on a single text value.</summary>
    /// <param name="field">Text to classify.</param>
    /// <returns>The prediction produced by the model.</returns>
    public FooFieldTypePrediction Predict(string field)
    {
        LoadModel();

        // The engine is disposable; release it once this single prediction has been made.
        using var predEngine = _mlContext.Model.CreatePredictionEngine<FooModelInput, FooFieldTypePrediction>(_trainedModel);

        var prediction = predEngine.Predict(new FooModelInput(field));

        Debug.WriteLine($"=============== Single Prediction - Result: {field}: {prediction.FooFieldType} ===============");

        return prediction;
    }

    /// <summary>
    /// Classifies a text value, trying cheap rule-based checks first and falling back to the
    /// model only when none of them matches.
    /// </summary>
    /// <param name="field">Raw text; it is passed through <c>CleanText()</c> before any check.</param>
    /// <returns>The detected field type.</returns>
    public FooFieldType PredictFooFieldType(string field)
    {
        field = field.CleanText();

        if (String.IsNullOrWhiteSpace(field))
        {
            return FooFieldType.SmartMode;
        }

        if (FooCertificateStatusFields.Descriptions.Value.Contains(field))
        {
            return FooFieldType.Status;
        }

        if (DateOnly.TryParse(field, out _))
        {
            return FooFieldType.CertDateStart;
        }

        if (FooTestingCenterFields.Descriptions.Value.Contains(field))
        {
            return FooFieldType.TestingCenter;
        }

        return Predict(field).FooFieldType;
    }

    /// <summary>Saves a model to <see cref="TrainedModelPath"/>.</summary>
    /// <param name="model">Model to persist.</param>
    /// <param name="trainingDataViewSchema">Schema of the data the model was trained on.</param>
    public void SaveModelAsFile(ITransformer model, DataViewSchema trainingDataViewSchema)
    {
        _mlContext.Model.Save(model, trainingDataViewSchema, TrainedModelPath);

        Debug.WriteLine($"The model is saved to {TrainedModelPath}");
    }

    /// <summary>
    /// Loads the training records, builds the pipeline and fits a new model. The fitted model
    /// is kept so that <see cref="Evaluate"/> and <see cref="Predict"/> use it.
    /// </summary>
    public void Train()
    {
        _trainTestData = LoadData();
        var pipeline = ProcessData();

        // Keep the fitted model; discarding it would leave Evaluate() with nothing to score.
        _trainedModel = BuildAndTrainModel(_trainTestData.TrainSet, pipeline);
    }

    /// <summary>
    /// Appends the SDCA maximum-entropy trainer and the key-to-value mapping of the predicted
    /// label to <paramref name="pipeline"/>, then fits the chain.
    /// </summary>
    /// <param name="splitTrainSet">Training portion of the data.</param>
    /// <param name="pipeline">Data-preparation pipeline producing the label and feature columns.</param>
    /// <returns>The fitted transformer chain.</returns>
    private TransformerChain<KeyToValueMappingTransformer> BuildAndTrainModel(IDataView splitTrainSet, IEstimator<ITransformer> pipeline)
    {
        var trainingPipeline = pipeline
            .Append(_mlContext.MulticlassClassification.Trainers.SdcaMaximumEntropy(LabelColumn, FeatureColumn))
            .Append(_mlContext.Transforms.Conversion.MapKeyToValue(PredictedLabelColumn));

        return trainingPipeline.Fit(splitTrainSet);
    }

    /// <summary>
    /// Builds one labelled example per text value found on the stored records and splits the
    /// result 80/20 into train and test sets.
    /// </summary>
    /// <returns>The train/test split.</returns>
    private TrainTestData LoadData()
    {
        // Materialize once: the source is enumerated six times below, and ML.NET re-enumerates
        // the sequence behind the IDataView on every pass over the data.
        var foos = _fooDbContext.FooSet.ToList();

        var modelInputs = new List<FooModelInput>();
        modelInputs.AddRange(foos.Select(x => new FooModelInput(x.ProductCategory, FooFieldType.ProductCategory)));
        modelInputs.AddRange(foos.Select(x => new FooModelInput(x.Enterprise, FooFieldType.EnterpriseName)));
        modelInputs.AddRange(foos.Select(x => new FooModelInput(x.ProductName, FooFieldType.ProductName)));
        modelInputs.AddRange(foos.SelectMany(x => x.Models).Select(x => new FooModelInput(x, FooFieldType.Model)));
        modelInputs.AddRange(foos.Select(x => new FooModelInput(x.CertificateNumber, FooFieldType.CertificateNo)));
        modelInputs.AddRange(foos.SelectMany(x => x.ReportNumbers).Select(x => new FooModelInput(x, FooFieldType.ReportNo)));

        var dataView = _mlContext.Data.LoadFromEnumerable(modelInputs);

        return _mlContext.Data.TrainTestSplit(dataView, testFraction: 0.2);
    }

    /// <summary>
    /// Builds the data-preparation pipeline: label value to key, text featurization, then a
    /// cache checkpoint.
    /// </summary>
    /// <returns>The estimator chain to which a trainer is appended.</returns>
    private EstimatorChain<ITransformer> ProcessData()
    {
        // FooModelInput.FooFieldType is exposed to ML.NET as the "Label" column through its
        // [ColumnName("Label")] attribute, so that is the input column name to read here;
        // the key-typed result replaces it under the same name.
        return _mlContext.Transforms.Conversion
            .MapValueToKey(outputColumnName: LabelColumn, inputColumnName: LabelColumn)
            .Append(_mlContext.Transforms.Text.FeaturizeText(inputColumnName: nameof(FooModelInput.Field), outputColumnName: FeatureColumn))
            .AppendCacheCheckpoint(_mlContext);
    }
}
/// <summary>
/// One input row for the classifier: a text value and, for training rows, the numeric value
/// of the <see cref="FooFieldType"/> it is labelled with.
/// </summary>
public class FooModelInput
{
    /// <summary>Label as an integer; exposed to ML.NET under the column name "Label".</summary>
    [ColumnName("Label")]
    public int FooFieldType { get; set; }

    /// <summary>The text value of the row.</summary>
    public string Field { get; set; }

    /// <summary>Creates a labelled row.</summary>
    /// <param name="field">The text value.</param>
    /// <param name="FooFieldType">The label for <paramref name="field"/>.</param>
    public FooModelInput(string field, FooFieldType FooFieldType)
        : this(field)
    {
        // The parameter has the same name as the property and shadows it, so `this.` is
        // required; without it the statement assigns an int to the enum-typed parameter.
        this.FooFieldType = (int)FooFieldType;
    }

    /// <summary>Creates an unlabelled row; the label keeps its default value of 0.</summary>
    /// <param name="field">The text value.</param>
    public FooModelInput(string field)
    {
        Field = field;
    }
}
/// <summary>
/// Output row of the classifier: the raw integer read from the "PredictedLabel" column and
/// the same value viewed as a <see cref="FooFieldType"/>.
/// </summary>
public class FooFieldTypePrediction
{
    /// <summary>The predicted label, obtained by casting <see cref="Prediction"/> to the enum.</summary>
    public FooFieldType FooFieldType
    {
        get { return (FooFieldType)Prediction; }
    }

    /// <summary>Integer value bound to the "PredictedLabel" column.</summary>
    [ColumnName("PredictedLabel")]
    public int Prediction { get; set; }
}

from machinelearning-samples.

Related Issues (20)

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Something interesting about the web. A new door to the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Something interesting about games, to make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.