Lower Level api #23

Open
basejn opened this issue Oct 1, 2018 · 9 comments


@basejn

basejn commented Oct 1, 2018

Hi, I want to use BrightWire to implement the following simple interface.

The problem is that the library is very tightly coupled to its data source abstraction.

 interface IBaseTrainableProbabilisticMulticategoryModel
    {
        //Construct with  NumClasses and NumFeatures

        int[] NumClasses { get; }
        int NumFeatures { get; }
        /// <summary>
        /// Train a batch 
        /// </summary>
        /// <param name="batchData"> [N,numfeatures]</param>
        /// <param name="batchLabels"> [N,subclasses] = class id</param>
        /// <param name="weights"> [N,subclasses] = weight for each subclass. Used to balance classes: weights each data point's contribution to the loss for each subclass</param>
        /// <returns>Loss</returns>
        float Train(float[][] batchData, int[][] batchLabels, float[][] weights);

        /// <summary>
        /// predict a batch
        /// </summary>
        /// <returns>[subclasses,N,probability distribution]</returns>
        float[][][] Predict(float[][] batchData);
    }

Is there a lower level API with which I can implement a feed forward neural network model that holds its state (weights), is serializable, and is always ready for training and prediction?

This is easily achievable with CNTK and TensorFlow. Their problem is that they only support 64-bit.

@jdermody
Owner

jdermody commented Oct 1, 2018

No, there's no lower level API unfortunately. But you could adapt your data, set the batch size to N and train for 1 iteration to get what you want...

You would need to implement an IDataSource that accepted the first two arguments to your Train method (something like VectorDataSource).
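
One piece such a data source would need is a way to turn the int[][] labels from the interface above into target vectors. The sketch below is plain C# with no BrightWire dependency. It assumes each row of class ids is encoded as concatenated one-hot blocks, one block per subclass group; LabelEncoding and its methods are hypothetical names, not part of the library.

```csharp
using System;
using System.Linq;

public static class LabelEncoding
{
    // Hypothetical helper: converts one row of class ids (one id per subclass
    // group) into a single target vector of concatenated one-hot blocks.
    public static float[] ToTargetVector(int[] classIds, int[] numClasses)
    {
        if (classIds.Length != numClasses.Length)
            throw new ArgumentException("Expected one class id per subclass group.");

        var target = new float[numClasses.Sum()];
        var offset = 0;
        for (var group = 0; group < numClasses.Length; group++)
        {
            var id = classIds[group];
            if (id < 0 || id >= numClasses[group])
                throw new ArgumentOutOfRangeException(nameof(classIds));

            target[offset + id] = 1f;      // mark the active class in this block
            offset += numClasses[group];   // move to the next group's block
        }
        return target;
    }

    // Encodes a whole batch: [N][subclasses] -> [N][sum(numClasses)]
    public static float[][] ToTargetVectors(int[][] batchLabels, int[] numClasses)
        => batchLabels.Select(row => ToTargetVector(row, numClasses)).ToArray();
}
```

With NumClasses = { 3, 2, 3 }, the label row { 1, 0, 2 } becomes an 8-element vector with ones at positions 1, 3 and 7.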

For your weights argument you would need to implement an IAction that weighted the error signal based on the weights (something like ConstrainSignal) and add it to the graph just before the backpropagation.
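
The weighting itself is simple arithmetic. As a minimal sketch, assuming the error signal is laid out as [N][sum(NumClasses)] with one block of columns per subclass group, each block is scaled by that sample's weight for the group. This is plain C#; ErrorWeighting is a hypothetical name, and wiring the logic into an IAction is not shown.

```csharp
using System;

public static class ErrorWeighting
{
    // errorSignal: [N][sum(numClasses)], weights: [N][numClasses.Length]
    // Returns a new matrix in which every group block of each row has been
    // multiplied by that row's weight for the group.
    public static float[][] Apply(float[][] errorSignal, float[][] weights, int[] numClasses)
    {
        var result = new float[errorSignal.Length][];
        for (var row = 0; row < errorSignal.Length; row++)
        {
            result[row] = new float[errorSignal[row].Length];
            var offset = 0;
            for (var group = 0; group < numClasses.Length; group++)
            {
                for (var i = 0; i < numClasses[group]; i++)
                    result[row][offset + i] = errorSignal[row][offset + i] * weights[row][group];
                offset += numClasses[group];
            }
        }
        return result;
    }
}
```

A weight of 0 for a group removes that sample's contribution to the group's loss, and a weight above 1 emphasises it, which is how under-represented classes can be balanced.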

@basejn
Author

basejn commented Oct 2, 2018

Thank you. I will give it a try.
The only thing that worries me is whether I can have N groups of softmax outputs, i.e. N probability distributions. For example, one item can have 3 subcategories: [dog, cat, horse] [male, female] [white, brown, black].

@jdermody
Owner

jdermody commented Oct 2, 2018

You would need to create a custom Softmax activation to achieve that. You could use the existing softmax activation as a guide, but instead of calculating softmax over the entire matrix it would split the matrix into one block of columns per distribution (GetNewMatrixFromColumns) in both the forward and backward passes, and then combine the blocks again afterwards as the output.
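
The split-and-recombine idea can be sketched on a single row of activations. This is plain C# rather than a BrightWire activation: GroupedSoftmax is a hypothetical name, the group sizes are assumed to come from NumClasses, and the backward pass uses the standard softmax derivative, which may need adapting to the library's error signal convention.

```csharp
using System;
using System.Linq;

public static class GroupedSoftmax
{
    // Forward pass: softmax applied separately to each group of columns.
    public static float[] Forward(float[] activations, int[] numClasses)
    {
        if (activations.Length != numClasses.Sum())
            throw new ArgumentException("Row length must equal the sum of the group sizes.");

        var output = new float[activations.Length];
        var offset = 0;
        foreach (var count in numClasses)
        {
            // subtract the group maximum for numerical stability
            var max = float.NegativeInfinity;
            for (var i = 0; i < count; i++)
                max = Math.Max(max, activations[offset + i]);

            var sum = 0.0;
            for (var i = 0; i < count; i++)
            {
                var e = Math.Exp(activations[offset + i] - max);
                output[offset + i] = (float)e;
                sum += e;
            }
            for (var i = 0; i < count; i++)
                output[offset + i] = (float)(output[offset + i] / sum);

            offset += count;
        }
        return output;
    }

    // Backward pass: within each group, dz_i = s_i * (g_i - sum_j(g_j * s_j)).
    public static float[] Backward(float[] softmaxOutput, float[] gradient, int[] numClasses)
    {
        var result = new float[gradient.Length];
        var offset = 0;
        foreach (var count in numClasses)
        {
            var dot = 0.0;
            for (var i = 0; i < count; i++)
                dot += gradient[offset + i] * softmaxOutput[offset + i];
            for (var i = 0; i < count; i++)
                result[offset + i] = (float)(softmaxOutput[offset + i] * (gradient[offset + i] - dot));
            offset += count;
        }
        return result;
    }
}
```

For the [dog, cat, horse] [male, female] [white, brown, black] example the group sizes would be { 3, 2, 3 }, and each block of the output sums to 1 independently of the others.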

@PatriceDargenton

Hi, how is it possible to use the NuGet packages if "No, there's no lower level API"? Is there an example using this package?
https://www.nuget.org/packages/BrightWire/
https://www.nuget.org/packages/BrightWire.Net4/

@PatriceDargenton

OK, there is no problem running all these examples in their original solution. But if I add a NuGet package such as BrightWire.Net4 to my own project and try to compile the XOR example, I cannot use DataTableBuilder because its accessibility is Friend (internal in C#) rather than Public, so the example cannot compile in my own project!

@jdermody
Owner

I see your point - the current design is that everything is created through indirection and the classes themselves are not public from the assembly. The drawback of this is that the usage is less obvious.

For example in BrightWire.Net4 you create a data table builder like this:

var dataTableBuilder = BrightWireProvider.CreateDataTableBuilder();
dataTableBuilder.AddColumn(ColumnType.Float, "capital costs");
dataTableBuilder.AddColumn(ColumnType.Float, "labour costs");
dataTableBuilder.AddColumn(ColumnType.Float, "energy costs");
dataTableBuilder.AddColumn(ColumnType.Float, "output", true);

In BrightWire (.NET Core) you do it like this:

var context = new BrightDataContext();
var dataTableBuilder = context.BuildTable();
dataTableBuilder.AddColumn(BrightDataType.Float, "capital costs");
dataTableBuilder.AddColumn(BrightDataType.Float, "labour costs");
dataTableBuilder.AddColumn(BrightDataType.Float, "energy costs");
dataTableBuilder.AddColumn(BrightDataType.Float, "output").SetTarget(true);

BrightWire (.NET Core) is more consistent in that everything is available through the context, which acts as an extension point via extension methods in other assemblies. I suppose this approach is that of a framework rather than a library, but perhaps the better design is to be both.
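
As a rough illustration of that design, the sketch below shows a context object extended through an extension method while the concrete class stays hidden behind a public interface. Every type here (DemoContext, IDemoTableBuilder, DemoTableExtensions) is a hypothetical stand-in and not BrightWire's actual API.

```csharp
using System.Collections.Generic;

// Hypothetical stand-in for a library context object (not BrightWire's real type).
public sealed class DemoContext { }

// The public surface that consumers program against.
public interface IDemoTableBuilder
{
    IDemoTableBuilder AddColumn(string name);
    IReadOnlyList<string> Columns { get; }
}

// This class could live in a separate assembly that references the context's assembly.
public static class DemoTableExtensions
{
    // Extension method: reads as if BuildTable were a member of the context.
    public static IDemoTableBuilder BuildTable(this DemoContext context) => new TableBuilder();

    // The implementation stays hidden; callers only ever see the interface.
    private sealed class TableBuilder : IDemoTableBuilder
    {
        private readonly List<string> _columns = new List<string>();

        public IReadOnlyList<string> Columns => _columns;

        public IDemoTableBuilder AddColumn(string name)
        {
            _columns.Add(name);
            return this;
        }
    }
}
```

A caller would write new DemoContext().BuildTable().AddColumn("capital costs") without ever naming the concrete builder type, which is why non-public classes do not prevent use from another project.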

@PatriceDargenton

OK, thanks! Now I have a problem with:
Dim testData = graph.CreateDataSource(data)
Visual Studio recommends one of these 3 casts, but none of them works, because BrightWire.TabularData.Helper.DataTableWriter cannot be cast to these types (InvalidCastException):
Dim testData = graph.CreateDataSource(CType(data, IReadOnlyList(Of Models.FloatTensor)))
Dim testData = graph.CreateDataSource(CType(data, IReadOnlyList(Of Models.FloatMatrix)))
Dim testData = graph.CreateDataSource(CType(data, IReadOnlyList(Of Models.FloatVector)))

@PatriceDargenton

OK, I found out how it works, thank you very much! Here it is with BrightWire.Net4 in VB.NET:

Public Sub TestXOR()

    Dim lap = BrightWireProvider.CreateLinearAlgebra
    ' Some training data that the network will learn. The XOR pattern looks like:
    ' 0 0 => 0
    ' 1 0 => 1
    ' 0 1 => 1
    ' 1 1 => 0
    Dim builder = BrightWireProvider.CreateDataTableBuilder()
    builder.AddColumn(ColumnType.Float, "X")
    builder.AddColumn(ColumnType.Float, "Y")
    builder.AddColumn(ColumnType.Float, "XOR", True)
    builder.Add(0.0!, 0.0!, 0.0!)
    builder.Add(1.0!, 0.0!, 1.0!)
    builder.Add(0.0!, 1.0!, 1.0!)
    builder.Add(1.0!, 1.0!, 0.0!)
    Dim data = builder
    Dim dataTable = builder.Build()

    ' create the graph
    Dim graph = New GraphFactory(lap)
    Dim errorMetric = graph.ErrorMetric.CrossEntropy
    ' use RMSProp gradient descent optimisation
    graph.CurrentPropertySet.Use(graph.GradientDescent.RmsProp)
    ' and gaussian weight initialisation
    graph.CurrentPropertySet.Use(graph.WeightInitialisation.Gaussian)
    ' create the engine
    Dim testData = graph.CreateDataSource(dataTable)

    Dim engine = graph.CreateTrainingEngine(testData, learningRate:=0.1!, batchSize:=4)
    ' create the network
    Const HIDDEN_LAYER_SIZE As Integer = 6
    With graph.Connect(engine)
        ' create a feed forward layer with sigmoid activation
        .AddFeedForward(HIDDEN_LAYER_SIZE).Add(graph.SigmoidActivation)
        ' create a second feed forward layer with sigmoid activation
        .AddFeedForward(engine.DataSource.OutputSize).Add(graph.SigmoidActivation)
        ' calculate the error and backpropagate the error signal
        .AddBackpropagation(errorMetric)
    End With

    ' train the network
    Dim executionContext = graph.CreateExecutionContext
    Dim i = 0
    Do While i < 1000
        engine.Train(executionContext)
        If i Mod 100 = 0 Then engine.Test(testData, errorMetric)
        i += 1
    Loop

    engine.Test(testData, errorMetric)
    ' create a new network to execute the learned network
    Dim networkGraph = engine.Graph
    Dim executionEngine = graph.CreateEngine(networkGraph)
    Dim output = executionEngine.Execute(testData)
    Debug.WriteLine("Average error = " &
        output.Average((Function(o) o.CalculateError(errorMetric))))

    ' print the values that have been learned
    For Each item In output
        For Each index In item.MiniBatchSequence.MiniBatch.Rows
            Dim row = dataTable.GetRow(index)
            Dim result = item.Output(index)
            Dim rOutput! = result.Data(0)
            Debug.WriteLine(
                row.GetField(Of Single)(0) & " XOR " &
                row.GetField(Of Single)(1) & " = " & rOutput.ToString("0.00000"))
        Next
    Next

End Sub
