Skip to content

Add scorers support in scheduler #1

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Apr 17, 2025
71 changes: 69 additions & 2 deletions pkg/epp/datastore/datastore.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"errors"
"fmt"
"sync"
"time"

corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/labels"
Expand All @@ -34,7 +35,9 @@ import (
)

const (
ModelNameIndexKey = "spec.modelName"
ModelNameIndexKey = "spec.modelName"
sessionKeepAliveTime = 60 * time.Minute // How long should an idle session be kept alive
sessionKeepAliveCheckFrequency = 15 * time.Minute // How often to check for overly idle sessions
)

var (
Expand Down Expand Up @@ -65,6 +68,9 @@ type Datastore interface {
PodDelete(namespacedName types.NamespacedName)
PodResyncAll(ctx context.Context, ctrlClient client.Client, pool *v1alpha2.InferencePool)

SetPodForSession(sessionID string, pod *backendmetrics.Pod)
GetPodForSession(sessionID string) *backendmetrics.Pod

// Clears the store state, happens when the pool gets deleted.
Clear()
}
Expand All @@ -75,8 +81,12 @@ func NewDatastore(parentCtx context.Context, pmf *backendmetrics.PodMetricsFacto
poolAndModelsMu: sync.RWMutex{},
models: make(map[string]*v1alpha2.InferenceModel),
pods: &sync.Map{},
sessions: &sync.Map{},
pmf: pmf,
}

go store.cleanupSessions(sessionKeepAliveCheckFrequency, sessionKeepAliveTime, parentCtx)

return store
}

Expand All @@ -90,7 +100,9 @@ type datastore struct {
models map[string]*v1alpha2.InferenceModel
// key: types.NamespacedName, value: backendmetrics.PodMetrics
pods *sync.Map
pmf *backendmetrics.PodMetricsFactory
// key: session id, value: *backendmetrics.Pod
sessions *sync.Map
pmf *backendmetrics.PodMetricsFactory
}

func (ds *datastore) Clear() {
Expand Down Expand Up @@ -291,6 +303,61 @@ func (ds *datastore) PodDelete(namespacedName types.NamespacedName) {
}
}

// sessionInfo is the per-session affinity record stored in datastore.sessions.
type sessionInfo struct {
	pod *backendmetrics.Pod // pod currently serving this session
	lru time.Time           // last time this mapping was (re)written; used for idle eviction
}

// cleanupSessions periodically evicts stale entries from the session-affinity
// map. Every keepAliveCheckFrequency it removes sessions whose last use is
// older than sessionKeepAlive, and runs until ctx is cancelled.
func (ds *datastore) cleanupSessions(keepAliveCheckFrequency time.Duration, sessionKeepAlive time.Duration, ctx context.Context) {
	logger := log.FromContext(ctx)

	logger.Info("Session-affinity cleanup started")
	ticker := time.NewTicker(keepAliveCheckFrequency)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			// Fixed: original message had a stray trailing colon.
			logger.Info("Session-affinity cleanup stopped")
			return
		case now := <-ticker.C:
			logger.Info("Session affinity checking")
			ds.sessions.Range(
				func(sessionID any, rawSessionInfo any) bool {
					info, ok := rawSessionInfo.(*sessionInfo)
					if !ok || now.Sub(info.lru) > sessionKeepAlive {
						// Drop the entry when the stored value has an
						// unexpected type or the session has been idle
						// longer than the keep-alive window.
						ds.sessions.Delete(sessionID)
					}
					return true
				})
		}
	}
}

// SetPodForSession records pod as the affinity target for the given session ID
// and refreshes the session's last-used timestamp.
func (ds *datastore) SetPodForSession(sessionID string, pod *backendmetrics.Pod) {
	info := &sessionInfo{
		pod: pod,
		lru: time.Now(),
	}
	ds.sessions.Store(sessionID, info)
}

// GetPodForSession returns the pod previously associated with the given
// session ID, or nil when the session is unknown or the stored value has an
// unexpected type.
func (ds *datastore) GetPodForSession(sessionID string) *backendmetrics.Pod {
	raw, found := ds.sessions.Load(sessionID)
	if !found {
		return nil
	}
	info, ok := raw.(*sessionInfo)
	if !ok {
		return nil
	}
	return info.pod
}

// selectorFromInferencePoolSelector converts an InferencePool label selector
// into a labels.Selector usable with the Kubernetes API machinery.
func selectorFromInferencePoolSelector(selector map[v1alpha2.LabelKey]v1alpha2.LabelValue) labels.Selector {
	stripped := stripLabelKeyAliasFromLabelMap(selector)
	return labels.SelectorFromSet(stripped)
}
Expand Down
16 changes: 15 additions & 1 deletion pkg/epp/handlers/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@ import (
"encoding/json"
"fmt"
"strconv"
"strings"
"time"

extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
"github.com/google/uuid"
"sigs.k8s.io/controller-runtime/pkg/log"
"sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2"
schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
Expand Down Expand Up @@ -62,12 +64,14 @@ func (s *StreamingServer) HandleRequestBody(
return reqCtx, errutil.Error{Code: errutil.BadConfiguration, Msg: fmt.Sprintf("error getting target model name for model %v", modelObj.Name)}
}
}

llmReq := &schedulingtypes.LLMRequest{
Model: model,
ResolvedTargetModel: modelName,
Critical: modelObj.Spec.Criticality != nil && *modelObj.Spec.Criticality == v1alpha2.Critical,
SessionID: reqCtx.SessionID,
}
logger.V(logutil.DEBUG).Info("LLM request assembled", "model", llmReq.Model, "targetModel", llmReq.ResolvedTargetModel, "critical", llmReq.Critical)
logger.V(logutil.DEBUG).Info("LLM request assembled", "model", llmReq.Model, "targetModel", llmReq.ResolvedTargetModel, "critical", llmReq.Critical, "session id", reqCtx.SessionID)

var err error
// Update target models in the body.
Expand Down Expand Up @@ -132,6 +136,16 @@ func (s *StreamingServer) HandleRequestBody(
func (s *StreamingServer) HandleRequestHeaders(ctx context.Context, reqCtx *RequestContext, req *extProcPb.ProcessingRequest_RequestHeaders) error {
reqCtx.RequestReceivedTimestamp = time.Now()

for _, header := range req.RequestHeaders.Headers.GetHeaders() {
value := string(header.RawValue)
if strings.ToLower(header.Key) == strings.ToLower(SessionIDHeader) && value != "" {
reqCtx.SessionID = value
}
}
if reqCtx.SessionID == "" {
reqCtx.SessionID = uuid.NewString()
}

// an EoS in the request headers means this request has no body or trailers.
if req.RequestHeaders.EndOfStream {
// We will route this request to a random pod as this is assumed to just be a GET
Expand Down
19 changes: 19 additions & 0 deletions pkg/epp/handlers/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ type RequestContext struct {
TargetPod string
TargetEndpoint string
Model string
SessionID string
ResolvedTargetModel string
RequestReceivedTimestamp time.Time
ResponseCompleteTimestamp time.Time
Expand Down Expand Up @@ -108,6 +109,8 @@ const (
TrailerResponseResponsesComplete StreamRequestState = 7
)

const SessionIDHeader = "Session-ID"

func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer) error {
ctx := srv.Context()
logger := log.FromContext(ctx)
Expand Down Expand Up @@ -197,6 +200,16 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
loggerTrace.Info("model server is streaming response")
}
}
// Save session id -> pod mapping
allPods := s.datastore.PodGetAll()

for _, pod := range allPods {
if pod.GetPod().NamespacedName.String() == reqCtx.TargetPod {
s.datastore.SetPodForSession(reqCtx.SessionID, pod.GetPod())
break
}
}

reqCtx.RequestState = ResponseRecieved
reqCtx.respHeaderResp = &extProcPb.ProcessingResponse{
Response: &extProcPb.ProcessingResponse_ResponseHeaders{
Expand All @@ -211,6 +224,12 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
RawValue: []byte("true"),
},
},
{
Header: &configPb.HeaderValue{
Key: SessionIDHeader,
RawValue: []byte(reqCtx.SessionID),
},
},
},
},
},
Expand Down
17 changes: 13 additions & 4 deletions pkg/epp/scheduling/scheduler.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ package scheduling
import (
"context"
"fmt"
"math/rand"

"sigs.k8s.io/controller-runtime/pkg/log"
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
Expand Down Expand Up @@ -116,21 +115,27 @@ var (
)

// NewScheduler builds a Scheduler backed by the given datastore, wiring in the
// default critical/sheddable filters and a scorer manager preloaded with a
// session-affinity scorer (weight 1).
func NewScheduler(datastore Datastore) *Scheduler {
	scorerMng := NewScorerMng()
	scorerMng.addScorer(NewSessionAffinityScorer(1, datastore))

	scheduler := &Scheduler{
		datastore:              datastore,
		criticalRequestFilter:  lowLatencyFilter,
		sheddableRequestFilter: sheddableRequestFilter,
		scorerMng:              scorerMng,
	}
	return scheduler
}

// Scheduler selects a target pod for an LLM request by filtering candidate
// pods and then ranking the survivors with the configured scorers.
type Scheduler struct {
	datastore              Datastore // source of pod metrics and session-affinity lookups
	criticalRequestFilter  Filter    // filter chain applied to critical requests
	sheddableRequestFilter Filter    // filter chain applied to sheddable requests
	scorerMng              *ScorerMng // combines scores from all registered scorers
}

// Datastore is the subset of the EPP datastore the scheduler depends on.
type Datastore interface {
	// PodGetAll returns all pods (with metrics) currently tracked for the pool.
	PodGetAll() []backendmetrics.PodMetrics
	// GetPodForSession returns the pod associated with the given session ID,
	// or nil when no affinity has been recorded.
	// Fixed: parameter renamed from "SessionID" to idiomatic lowerCamel.
	GetPodForSession(sessionID string) *backendmetrics.Pod
}

// Schedule finds the target pod based on metrics and the requested lora adapter.
Expand All @@ -154,7 +159,11 @@ func (s *Scheduler) Schedule(ctx context.Context, req *types.LLMRequest) (target
if err != nil || len(pods) == 0 {
return nil, fmt.Errorf("failed to apply filter, resulted %v pods, this should never happen: %w", len(pods), err)
}
logger.V(logutil.DEBUG).Info(fmt.Sprintf("Selecting a random pod from %d candidates: %+v", len(pods), pods))
i := rand.Intn(len(pods))
return pods[i], nil

selectedPod, err := s.scorerMng.scoreTargets(sCtx, pods)
if err != nil {
return nil, fmt.Errorf("failed to apply scorers: %w", err)
}

return selectedPod, nil
}
4 changes: 4 additions & 0 deletions pkg/epp/scheduling/scheduler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,3 +230,7 @@ func (fds *fakeDataStore) PodGetAll() []backendmetrics.PodMetrics {
}
return pm
}

// GetPodForSession is a stub satisfying the scheduling Datastore interface;
// the fake never records session affinity, so it always returns nil.
func (fds *fakeDataStore) GetPodForSession(sessionID string) *backendmetrics.Pod {
	return nil
}
112 changes: 112 additions & 0 deletions pkg/epp/scheduling/scorer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
/*
Copyright 2025 The Kubernetes Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package scheduling

import (
"fmt"
"math/rand/v2"

"sigs.k8s.io/controller-runtime/pkg/log"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
)

// PodScore couples a candidate pod with the score a scorer assigned to it.
type PodScore struct {
	Score float64 // contribution to the pod's total; may be negative depending on scorer weights
	Pod   *types.PodMetrics
}

// Scorer is the interface that scorers must implement.
type Scorer interface {
	// ScoreTargets assigns scores to the given candidate pods for the request
	// carried in ctx.
	ScoreTargets(ctx *types.Context, pods []*types.PodMetrics) ([]PodScore, error)
}

// ScorerMng manages the set of registered scorers and combines their scores
// when selecting a target pod. (Original comment was a copy-paste of the
// Scorer interface comment.)
type ScorerMng struct {
	scorers []Scorer // scorers whose outputs are summed per pod in scoreTargets
}

// NewScorerMng creates an empty scorer manager; scorers are registered
// afterwards via addScorer.
func NewScorerMng() *ScorerMng {
	mng := &ScorerMng{}
	mng.scorers = make([]Scorer, 0)
	return mng
}

// addScorer registers an additional scorer whose scores will be summed into
// each pod's total during scoreTargets.
func (sm *ScorerMng) addScorer(scorer Scorer) {
	sm.scorers = append(sm.scorers, scorer)
}

// scoreTargets ranks the given pods by the summed scores of all registered
// scorers and returns the best one. Invalid (nil/incomplete) pods are skipped.
// A failing scorer is logged and excluded from the total rather than aborting
// selection. Ties for the maximum score are broken uniformly at random.
func (sm *ScorerMng) scoreTargets(ctx *types.Context, pods []*types.PodMetrics) (*types.PodMetrics, error) {
	logger := log.FromContext(ctx)

	podsTotalScore := make(map[*types.PodMetrics]float64)
	validPods := make([]*types.PodMetrics, 0, len(pods))

	// Initialize a zero score for every structurally valid pod.
	for _, pod := range pods {
		if pod == nil || pod.Pod == nil || pod.Metrics == nil {
			logger.Info("Invalid/empty pod skipped in scoring process")
			continue
		}
		validPods = append(validPods, pod)
		podsTotalScore[pod] = 0.0
	}

	if len(validPods) == 0 {
		// Fixed: error string lowercased per Go convention.
		return nil, fmt.Errorf("empty list of valid pods to score")
	}

	// Accumulate scores from all scorers; a failing scorer is skipped so the
	// remaining scorers still contribute.
	for _, scorer := range sm.scorers {
		scoredPods, err := scorer.ScoreTargets(ctx, validPods)
		if err != nil {
			logger.Error(err, "Score targets returned error in scorer")
			continue
		}
		for _, scoredPod := range scoredPods {
			podsTotalScore[scoredPod.Pod] += scoredPod.Score
		}
	}

	// Collect all pods sharing the maximum total score. Scores can be
	// negative, so the first entry seeds the maximum instead of 0.0.
	var highestScoreTargets []*types.PodMetrics
	maxScore := 0.0
	isFirst := true

	for pod, score := range podsTotalScore {
		switch {
		case isFirst:
			// BUG FIX: the original never cleared isFirst, so every iteration
			// took this branch and the last map entry (random order) always
			// won, ignoring the computed scores entirely.
			isFirst = false
			maxScore = score
			highestScoreTargets = []*types.PodMetrics{pod}
		case score > maxScore:
			maxScore = score
			highestScoreTargets = []*types.PodMetrics{pod}
		case score == maxScore:
			highestScoreTargets = append(highestScoreTargets, pod)
		}
	}

	// Single pod with max score.
	if len(highestScoreTargets) == 1 {
		return highestScoreTargets[0], nil
	}

	// Select a random pod from the list of pods with max score.
	return highestScoreTargets[rand.IntN(len(highestScoreTargets))], nil
}
Loading