Commit 5d6050f

feat: added grpc client, model interface
1 parent 36ac248

3 files changed: +188 -0 lines changed

client/client.go

+111
@@ -0,0 +1,111 @@
package client

import (
    "github.com/dev6699/face/model"
    "github.com/dev6699/face/protobuf"
    "google.golang.org/grpc"
    "google.golang.org/grpc/credentials/insecure"
)

var (
    conn           *grpc.ClientConn
    client         protobuf.GRPCInferenceServiceClient
    modelsMetadata = make(map[string]*modelMetadata)
)

// Init initializes the grpc connection and fetches metadata for all models from the grpc server.
func Init(url string, models []model.ModelMeta) error {
    var err error
    conn, err = grpc.NewClient(url, grpc.WithTransportCredentials(insecure.NewCredentials()))
    if err != nil {
        return err
    }

    client = protobuf.NewGRPCInferenceServiceClient(conn)
    for _, m := range models {
        meta, err := newModelMetadata(client, m.ModelName(), m.ModelVersion())
        if err != nil {
            return err
        }

        modelsMetadata[m.ModelName()] = meta
    }
    return nil
}

// Close tears down the underlying grpc connection.
func Close() error {
    return conn.Close()
}

// Infer is a generic function that takes a modelFactory to create a model.Model and an input for model.PreProcess(),
// then performs the infer request automatically based on the model's metadata.
func Infer[I, O any](modelFactory func() model.Model[I, O], input I) (O, error) {
    var zeroOutput O
    model := modelFactory()
    contents, err := model.PreProcess(input)
    if err != nil {
        return zeroOutput, err
    }

    modelInferRequest := modelsMetadata[model.ModelName()].formInferRequest(contents)

    inferResponse, err := ModelInferRequest(client, modelInferRequest)
    if err != nil {
        return zeroOutput, err
    }

    return model.PostProcess(inferResponse.RawOutputContents)
}

type modelMetadata struct {
    modelName    string
    modelVersion string
    *protobuf.ModelMetadataResponse
}

func newModelMetadata(client protobuf.GRPCInferenceServiceClient, modelName string, modelVersion string) (*modelMetadata, error) {
    metaResponse, err := ModelMetadataRequest(client, modelName, modelVersion)
    if err != nil {
        return nil, err
    }

    return &modelMetadata{
        modelName:             modelName,
        modelVersion:          modelVersion,
        ModelMetadataResponse: metaResponse,
    }, nil
}

func (m *modelMetadata) formInferRequest(contents []*protobuf.InferTensorContents) *protobuf.ModelInferRequest {
    inputs := []*protobuf.ModelInferRequest_InferInputTensor{}
    for i, c := range contents {
        input := m.Inputs[i]
        shape := input.Shape
        // Replace a dynamic batch dimension (-1) with a concrete batch size of 1.
        if shape[0] == -1 {
            shape[0] = 1
        }
        inputs = append(inputs, &protobuf.ModelInferRequest_InferInputTensor{
            Name:     input.Name,
            Datatype: input.Datatype,
            Shape:    shape,
            Contents: c,
        })
    }

    outputs := make([]*protobuf.ModelInferRequest_InferRequestedOutputTensor, len(m.Outputs))
    for i, o := range m.Outputs {
        outputs[i] = &protobuf.ModelInferRequest_InferRequestedOutputTensor{
            Name: o.Name,
        }
    }

    return &protobuf.ModelInferRequest{
        ModelName:    m.modelName,
        ModelVersion: m.modelVersion,
        Inputs:       inputs,
        Outputs:      outputs,
    }
}
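
A minimal usage sketch (not part of this commit): it wires Init, Infer, and Close around a hypothetical add_one model. The server address, the model itself, and the Fp32Contents field (taken from the KServe v2 generated protobuf) are assumptions for illustration.

// Usage sketch only; addOne, its tensor layout, and the server address are illustrative assumptions.
package main

import (
    "encoding/binary"
    "log"
    "math"

    "github.com/dev6699/face/client"
    "github.com/dev6699/face/model"
    "github.com/dev6699/face/protobuf"
)

// addOne is a hypothetical model that sends one FP32 tensor and reads a single FP32 value back.
type addOne struct{}

func (addOne) ModelName() string    { return "add_one" }
func (addOne) ModelVersion() string { return "1" }

func (addOne) PreProcess(input []float32) ([]*protobuf.InferTensorContents, error) {
    // Fp32Contents is assumed from the KServe v2 generated protobuf.
    return []*protobuf.InferTensorContents{{Fp32Contents: input}}, nil
}

func (addOne) PostProcess(rawOutputContents [][]byte) (float32, error) {
    // Raw output bytes are little-endian FP32 per the KServe v2 protocol.
    return math.Float32frombits(binary.LittleEndian.Uint32(rawOutputContents[0][:4])), nil
}

func main() {
    // Init dials the server once and caches metadata for every registered model.
    if err := client.Init("localhost:8001", []model.ModelMeta{addOne{}}); err != nil {
        log.Fatal(err)
    }
    defer client.Close()

    // Infer runs PreProcess -> formInferRequest -> ModelInferRequest -> PostProcess.
    out, err := client.Infer(func() model.Model[[]float32, float32] { return addOne{} }, []float32{1, 2, 3})
    if err != nil {
        log.Fatal(err)
    }
    log.Println("result:", out)
}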

client/conn.go

+61
@@ -0,0 +1,61 @@
package client

import (
    "context"
    "time"

    "github.com/dev6699/face/protobuf"
)

var requestTimeout = 10 * time.Second

func ServerLiveRequest(client protobuf.GRPCInferenceServiceClient) (*protobuf.ServerLiveResponse, error) {
    ctx, cancel := context.WithTimeout(context.Background(), requestTimeout)
    defer cancel()

    serverLiveRequest := protobuf.ServerLiveRequest{}
    serverLiveResponse, err := client.ServerLive(ctx, &serverLiveRequest)
    if err != nil {
        return nil, err
    }
    return serverLiveResponse, nil
}

func ServerReadyRequest(client protobuf.GRPCInferenceServiceClient) (*protobuf.ServerReadyResponse, error) {
    ctx, cancel := context.WithTimeout(context.Background(), requestTimeout)
    defer cancel()

    serverReadyRequest := protobuf.ServerReadyRequest{}
    serverReadyResponse, err := client.ServerReady(ctx, &serverReadyRequest)
    if err != nil {
        return nil, err
    }
    return serverReadyResponse, nil
}

func ModelMetadataRequest(client protobuf.GRPCInferenceServiceClient, modelName string, modelVersion string) (*protobuf.ModelMetadataResponse, error) {
    ctx, cancel := context.WithTimeout(context.Background(), requestTimeout)
    defer cancel()

    modelMetadataRequest := protobuf.ModelMetadataRequest{
        Name:    modelName,
        Version: modelVersion,
    }
    modelMetadataResponse, err := client.ModelMetadata(ctx, &modelMetadataRequest)
    if err != nil {
        return nil, err
    }
    return modelMetadataResponse, nil
}

func ModelInferRequest(client protobuf.GRPCInferenceServiceClient, modelInferRequest *protobuf.ModelInferRequest) (*protobuf.ModelInferResponse, error) {
    ctx, cancel := context.WithTimeout(context.Background(), requestTimeout)
    defer cancel()

    modelInferResponse, err := client.ModelInfer(ctx, modelInferRequest)
    if err != nil {
        return nil, err
    }

    return modelInferResponse, nil
}
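
These exported helpers can also be used directly, for example to gate startup on server health before calling client.Init. A short sketch under the same assumptions as above (the server address is illustrative; the Live and Ready field names follow the KServe v2 protocol):

package main

import (
    "log"

    "github.com/dev6699/face/client"
    "github.com/dev6699/face/protobuf"
    "google.golang.org/grpc"
    "google.golang.org/grpc/credentials/insecure"
)

func main() {
    // Dial the inference server; the address is an assumption for illustration.
    conn, err := grpc.NewClient("localhost:8001", grpc.WithTransportCredentials(insecure.NewCredentials()))
    if err != nil {
        log.Fatal(err)
    }
    defer conn.Close()

    c := protobuf.NewGRPCInferenceServiceClient(conn)

    // Each helper wraps its call in the 10-second requestTimeout defined above.
    live, err := client.ServerLiveRequest(c)
    if err != nil {
        log.Fatal(err)
    }
    ready, err := client.ServerReadyRequest(c)
    if err != nil {
        log.Fatal(err)
    }
    // Live and Ready field names are assumed from the KServe v2 protocol.
    log.Printf("server live=%v ready=%v", live.Live, ready.Ready)
}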

model/model.go

+16
@@ -0,0 +1,16 @@
package model

import (
    "github.com/dev6699/face/protobuf"
)

type Model[I any, O any] interface {
    ModelMeta
    PreProcess(input I) ([]*protobuf.InferTensorContents, error)
    PostProcess(rawOutputContents [][]byte) (O, error)
}

type ModelMeta interface {
    ModelName() string
    ModelVersion() string
}
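
A sketch of one possible Model implementation (illustrative, not part of this commit): a hypothetical face_embedder that packs a float vector into a single FP32 tensor in PreProcess and decodes the raw little-endian FP32 output bytes in PostProcess. The Fp32Contents field name is assumed from the KServe v2 generated protobuf.

package embedder

import (
    "encoding/binary"
    "math"

    "github.com/dev6699/face/protobuf"
)

// Embedder is a hypothetical model that maps a normalized float vector to an embedding vector.
type Embedder struct{}

func (Embedder) ModelName() string    { return "face_embedder" }
func (Embedder) ModelVersion() string { return "1" }

// PreProcess packs the input vector into a single FP32 tensor.
func (Embedder) PreProcess(input []float32) ([]*protobuf.InferTensorContents, error) {
    return []*protobuf.InferTensorContents{
        {Fp32Contents: input}, // field name assumed from the KServe v2 protobuf
    }, nil
}

// PostProcess decodes the first output tensor's raw little-endian FP32 bytes.
func (Embedder) PostProcess(rawOutputContents [][]byte) ([]float32, error) {
    raw := rawOutputContents[0]
    out := make([]float32, len(raw)/4)
    for i := range out {
        out[i] = math.Float32frombits(binary.LittleEndian.Uint32(raw[i*4 : i*4+4]))
    }
    return out, nil
}

Because Embedder satisfies model.Model[[]float32, []float32], it can be passed straight to client.Infer via a factory such as func() model.Model[[]float32, []float32] { return Embedder{} }.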
