Skip to content

Commit

Permalink
feat(inference): add a proxy to easily switch models (#99)
Browse files Browse the repository at this point in the history
  • Loading branch information
kaysonyu authored Oct 25, 2024
1 parent b61d9bb commit d22e4a1
Show file tree
Hide file tree
Showing 7 changed files with 260 additions and 5 deletions.
2 changes: 1 addition & 1 deletion modules/inference/example/dev/example_workspace.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
modules:
inference:
path: oci://ghcr.io/kusionstack/inference
version: 0.1.0-beta.5
version: 0.1.0
configs:
default: {}
network:
Expand Down
2 changes: 1 addition & 1 deletion modules/inference/example/dev/kcl.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
name = "example"

[dependencies]
inference = { oci = "oci://ghcr.io/kusionstack/inference", tag = "0.1.0-beta.5" }
inference = { oci = "oci://ghcr.io/kusionstack/inference", tag = "0.1.0" }
service = {oci = "oci://ghcr.io/kusionstack/service", tag = "0.1.0" }
kam = { git = "https://github.com/KusionStack/kam.git", tag = "0.2.0" }
network = { oci = "oci://ghcr.io/kusionstack/network", tag = "0.2.0" }
Expand Down
2 changes: 1 addition & 1 deletion modules/inference/kcl.mod
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
[package]
name = "inference"
version = "0.1.0-beta.5"
version = "0.1.0"
2 changes: 1 addition & 1 deletion modules/inference/src/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ TEST?=$$(go list ./... | grep -v 'vendor')
###### change variables below according to your own modules ###
NAMESPACE=kusionstack
NAME=inference
VERSION=0.1.0-beta.5
VERSION=0.1.0
BINARY=../bin/kusion-module-${NAME}_${VERSION}

LOCAL_ARCH := $(shell uname -m)
Expand Down
7 changes: 7 additions & 0 deletions modules/inference/src/inference.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,13 @@ var (
OllamaImage = "ollama/ollama"
)

// Settings for the model-switching proxy that is deployed alongside the
// inference framework.
var (
	// ProxyName is the base name from which the proxy's resource names and
	// its "accessory" label value are derived.
	ProxyName = "proxy"

	// ProxyPort is the container port declared on the proxy container and
	// used as the target port of the proxy Service.
	ProxyPort = 5000

	// ProxyImage is the container image run by the proxy Deployment.
	ProxyImage = "kangy126/proxy"
)

func main() {
server.Start(&Inference{})
}
Expand Down
139 changes: 138 additions & 1 deletion modules/inference/src/ollama_frame.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,21 @@ func (infer *Inference) GenerateOllamaResource(request *module.GeneratorRequest)
}
resources = append(resources, *svc)

patcher, err := infer.GenerateEnv(svcName)
// Build Kubernetes Deployment for proxy.
deploymentProxy, err := infer.generateProxyDeployment(request, svcName)
if err != nil {
return nil, nil, err
}
resources = append(resources, *deploymentProxy)

// Build Kubernetes Service for proxy.
svcProxy, svcNameProxy, err := infer.generateProxyService(request)
if err != nil {
return nil, nil, err
}
resources = append(resources, *svcProxy)

patcher, err := infer.GenerateEnv(svcNameProxy)
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -200,3 +214,126 @@ func (infer *Inference) generateMatchLabels() map[string]string {
"accessory": strings.ToLower(infer.Framework),
}
}

// generateMatchLabelsForProxy returns the label set used to tie the proxy's
// Kubernetes resources together: a single "accessory" label whose value is
// the lower-cased ProxyName. A new map is returned on every call.
func (infer *Inference) generateMatchLabelsForProxy() map[string]string {
	labels := make(map[string]string, 1)
	labels["accessory"] = strings.ToLower(ProxyName)
	return labels
}

// generateProxyPodSpec builds the Kubernetes PodSpec for the proxy.
//
// The pod has a single container running ProxyImage that declares ProxyPort
// as its container port. Two environment variables are passed to it:
// MODEL (infer.Model) and FRAMEWORK_URL (svcName, the value supplied by the
// caller). The request argument is currently unused, and the returned error
// is always nil.
func (infer *Inference) generateProxyPodSpec(_ *module.GeneratorRequest, svcName string) (v1.PodSpec, error) {
	baseName := strings.ToLower(ProxyName)

	// Kubernetes port names may be at most 15 characters long, so the
	// derived name is truncated when it would exceed that.
	const maxPortNameLen = 15
	portName := baseName + inferContainerPortSuffix
	if len(portName) > maxPortNameLen {
		portName = portName[:maxPortNameLen]
	}

	proxyContainer := v1.Container{
		Name:  baseName + inferContainerSuffix,
		Image: ProxyImage,
		Ports: []v1.ContainerPort{
			{
				Name:          portName,
				ContainerPort: int32(ProxyPort),
			},
		},
		Env: []v1.EnvVar{
			{Name: "MODEL", Value: infer.Model},
			{Name: "FRAMEWORK_URL", Value: svcName},
		},
	}

	return v1.PodSpec{Containers: []v1.Container{proxyContainer}}, nil
}

// generateProxyDeployment builds the Kubernetes Deployment for the proxy and
// wraps it as a Kusion resource.
//
// The Deployment is named lower(ProxyName)+inferDeploymentSuffix, is placed in
// the namespace request.Project, and uses generateMatchLabelsForProxy for both
// its selector and its pod template labels. svcName is forwarded to
// generateProxyPodSpec.
//
// It returns a non-nil error if the pod spec cannot be generated or if the
// Deployment cannot be wrapped into a Kusion resource; in that case the
// returned resource is nil.
func (infer *Inference) generateProxyDeployment(request *module.GeneratorRequest, svcName string) (*apiv1.Resource, error) {
	podSpec, err := infer.generateProxyPodSpec(request, svcName)
	if err != nil {
		// Propagate the failure: returning (nil, nil) here would make the
		// caller dereference a nil resource.
		return nil, err
	}

	deployment := &appsv1.Deployment{
		TypeMeta: metav1.TypeMeta{
			Kind:       "Deployment",
			APIVersion: appsv1.SchemeGroupVersion.String(),
		},
		ObjectMeta: metav1.ObjectMeta{
			Name:      strings.ToLower(ProxyName) + inferDeploymentSuffix,
			Namespace: request.Project,
		},
		Spec: appsv1.DeploymentSpec{
			Selector: &metav1.LabelSelector{
				MatchLabels: infer.generateMatchLabelsForProxy(),
			},
			Template: v1.PodTemplateSpec{
				ObjectMeta: metav1.ObjectMeta{
					Labels: infer.generateMatchLabelsForProxy(),
				},
				Spec: podSpec,
			},
		},
	}

	resourceID := module.KubernetesResourceID(deployment.TypeMeta, deployment.ObjectMeta)
	resource, err := module.WrapK8sResourceToKusionResource(resourceID, deployment)
	if err != nil {
		return nil, err
	}

	return resource, nil
}

// generateProxyService builds the ClusterIP Service for the proxy and wraps it
// as a Kusion resource.
//
// The Service is named lower(ProxyName)+inferServiceSuffix, lives in the
// namespace request.Project, listens on CalledPort and forwards to ProxyPort
// on pods selected by generateMatchLabelsForProxy.
//
// It returns the wrapped resource, the Service name and an error. The Service
// name is returned even when wrapping fails; the resource is nil in that case.
func (infer *Inference) generateProxyService(request *module.GeneratorRequest) (*apiv1.Resource, string, error) {
	name := strings.ToLower(ProxyName) + inferServiceSuffix

	proxySvc := &v1.Service{
		TypeMeta: metav1.TypeMeta{
			Kind:       "Service",
			APIVersion: v1.SchemeGroupVersion.String(),
		},
		ObjectMeta: metav1.ObjectMeta{
			Name:      name,
			Namespace: request.Project,
			Labels:    infer.generateMatchLabelsForProxy(),
		},
		Spec: v1.ServiceSpec{
			Type: v1.ServiceTypeClusterIP,
			Ports: []v1.ServicePort{
				{
					Port: int32(CalledPort),
					TargetPort: intstr.IntOrString{
						Type:   intstr.Int,
						IntVal: int32(ProxyPort),
					},
				},
			},
			Selector: infer.generateMatchLabelsForProxy(),
		},
	}

	id := module.KubernetesResourceID(proxySvc.TypeMeta, proxySvc.ObjectMeta)
	wrapped, err := module.WrapK8sResourceToKusionResource(id, proxySvc)
	if err != nil {
		return nil, name, err
	}
	return wrapped, name, nil
}
111 changes: 111 additions & 0 deletions modules/inference/src/ollama_frame_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,114 @@ func TestInferenceModule_GenerateOllamaService(t *testing.T) {
assert.Equal(t, strings.ToLower(infer.Framework)+inferServiceSuffix, svcName)
assert.NoError(t, err)
}

// TestInferenceModule_GenerateProxyPodSpec checks that the proxy PodSpec
// contains a single container using ProxyImage, exposing ProxyPort, and
// carrying the MODEL and FRAMEWORK_URL environment variables.
//
// The generated PodSpec is a struct value, so a NotNil assertion on it can
// never fail; the contents are asserted instead.
func TestInferenceModule_GenerateProxyPodSpec(t *testing.T) {
	r := &module.GeneratorRequest{
		Project: "test-project",
		Stack:   "test-stack",
		App:     "test-app",
		Workload: &v1.Workload{
			Header: v1.Header{
				Type: "Service",
			},
			Service: &v1.Service{},
		},
	}

	infer := &Inference{
		Model:       "qwen",
		Framework:   "Ollama",
		System:      "",
		Template:    "",
		TopK:        40,
		TopP:        0.9,
		Temperature: 0.8,
		NumPredict:  128,
		NumCtx:      2048,
	}

	const frameworkSvc = "ollama-svc"
	res, err := infer.generateProxyPodSpec(r, frameworkSvc)
	assert.NoError(t, err)

	// Stop early if the container is missing so the index below is safe.
	if !assert.Len(t, res.Containers, 1) {
		return
	}
	container := res.Containers[0]
	assert.Equal(t, ProxyImage, container.Image)

	if assert.Len(t, container.Ports, 1) {
		assert.Equal(t, int32(ProxyPort), container.Ports[0].ContainerPort)
	}

	env := make(map[string]string, len(container.Env))
	for _, e := range container.Env {
		env[e.Name] = e.Value
	}
	assert.Equal(t, infer.Model, env["MODEL"])
	assert.Equal(t, frameworkSvc, env["FRAMEWORK_URL"])
}

func TestInferenceModule_GenerateProxyDeployment(t *testing.T) {
r := &module.GeneratorRequest{
Project: "test-project",
Stack: "test-stack",
App: "test-app",
Workload: &v1.Workload{
Header: v1.Header{
Type: "Service",
},
Service: &v1.Service{},
},
}

infer := &Inference{
Model: "qwen",
Framework: "Ollama",
System: "",
Template: "",
TopK: 40,
TopP: 0.9,
Temperature: 0.8,
NumPredict: 128,
NumCtx: 2048,
}

res, err := infer.generateProxyDeployment(r, "ollama-svc")

assert.NotNil(t, res)
assert.NoError(t, err)
}

func TestInferenceModule_GenerateProxyService(t *testing.T) {
r := &module.GeneratorRequest{
Project: "test-project",
Stack: "test-stack",
App: "test-app",
Workload: &v1.Workload{
Header: v1.Header{
Type: "Service",
},
Service: &v1.Service{},
},
}

infer := &Inference{
Model: "qwen",
Framework: "Ollama",
System: "",
Template: "",
TopK: 40,
TopP: 0.9,
Temperature: 0.8,
NumPredict: 128,
NumCtx: 2048,
}

res, svcName, err := infer.generateProxyService(r)

assert.NotNil(t, res)
assert.NotNil(t, svcName)
assert.Equal(t, strings.ToLower(ProxyName)+inferServiceSuffix, svcName)
assert.NoError(t, err)
}

// TestInferenceModule_GenerateMatchLabels checks that the "accessory" label
// equals the lower-cased framework name.
func TestInferenceModule_GenerateMatchLabels(t *testing.T) {
	t.Run("generate matchLabels", func(t *testing.T) {
		infer := &Inference{Framework: "Ollama"}

		want := strings.ToLower(infer.Framework)
		got := infer.generateMatchLabels()["accessory"]

		assert.Equal(t, want, got)
	})
}

// TestInferenceModule_GenerateMatchLabelsForProxy checks that the proxy's
// "accessory" label equals the lower-cased ProxyName.
func TestInferenceModule_GenerateMatchLabelsForProxy(t *testing.T) {
	t.Run("generate matchLabels for proxy", func(t *testing.T) {
		infer := &Inference{}

		want := strings.ToLower(ProxyName)
		got := infer.generateMatchLabelsForProxy()["accessory"]

		assert.Equal(t, want, got)
	})
}

0 comments on commit d22e4a1

Please sign in to comment.