dex/vendor/google.golang.org/grpc/grpclb/grpclb.go

451 lines
10 KiB
Go

/*
*
* Copyright 2016, Google Inc.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
* * Neither the name of Google Inc. nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
*/
// Package grpclb implements the load balancing protocol defined at
// https://github.com/grpc/grpc/blob/master/doc/load-balancing.md.
// The implementation is currently EXPERIMENTAL.
package grpclb
import (
"errors"
"fmt"
"sync"
"golang.org/x/net/context"
"google.golang.org/grpc"
lbpb "google.golang.org/grpc/grpclb/grpc_lb_v1"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/naming"
)
// Balancer returns a grpc.Balancer speaking the grpclb protocol. The supplied
// resolver r is stored on the balancer and used by Start to resolve the target.
func Balancer(r naming.Resolver) grpc.Balancer {
	lb := new(balancer)
	lb.r = r
	return lb
}
// remoteBalancerInfo describes one remote load balancer learned from the name
// resolver (see watchAddrUpdates).
type remoteBalancerInfo struct {
// addr is the balancer address, built from a naming update's Addr and Metadata.
addr grpc.Address
// name is not assigned anywhere in the code visible here.
// NOTE(review): presumably the balancer's server name — confirm before relying on it.
name string
}
// addrInfo consists of the information of a backend server.
type addrInfo struct {
// addr is the backend address, formatted as "ip:port" by processServerList.
addr grpc.Address
// connected is the flag Get checks when looking for a usable backend; Up and
// down are the methods meant to maintain it.
connected bool
// dropRequest is neither read nor written in the visible code (see the
// "Support dropRequest feature" TODO in processServerList).
dropRequest bool
}
// balancer is the grpclb implementation of grpc.Balancer. It watches the name
// resolver for remote balancer addresses, asks the first one for a backend
// server list, and hands those backends out round-robin from Get.
type balancer struct {
// r is the resolver passed to Balancer; Start fails if it is nil.
r naming.Resolver
// mu is held by the methods while they read or write the fields below.
// (Notify returns addrCh without taking it.)
mu sync.Mutex
seq int // a sequence number to make sure addrCh does not get stale addresses.
// w is the watcher returned by r.Resolve in Start; Close closes it.
w naming.Watcher
// addrCh is the channel returned by Notify. processServerList sends each new
// backend address list on it and Close closes it.
addrCh chan []grpc.Address
// rbs is the set of known remote balancers, maintained by watchAddrUpdates.
// Only rbs[0] is used.
rbs []remoteBalancerInfo
// addrs is the current backend list, replaced wholesale by processServerList.
addrs []addrInfo
// next is the round-robin cursor into addrs used by Get.
next int
// waitCh is created lazily by blocking Get callers and closed to wake them;
// it is nil when nobody is waiting.
waitCh chan struct{}
// done is set by Close; once set, the other methods bail out.
done bool
}
// watchAddrUpdates blocks on w.Next() for one batch of naming updates, applies
// the additions and deletions to b.rbs, and, if the first remote balancer in
// the list changed as a result, publishes the new one on ch.
//
// It returns the error from w.Next(), grpc.ErrClientConnClosing if the balancer
// has been closed, or nil after a batch has been applied.
func (b *balancer) watchAddrUpdates(w naming.Watcher, ch chan remoteBalancerInfo) error {
	batch, err := w.Next()
	if err != nil {
		return err
	}

	b.mu.Lock()
	defer b.mu.Unlock()
	if b.done {
		return grpc.ErrClientConnClosing
	}

	// Remember which balancer was first before applying the batch so a change
	// can be detected afterwards.
	var previous remoteBalancerInfo
	if len(b.rbs) > 0 {
		previous = b.rbs[0]
	}

	// indexOf reports the position of target in b.rbs, or -1 if absent.
	// TODO: Is the same addr with different server name a different balancer?
	indexOf := func(target grpc.Address) int {
		for i := range b.rbs {
			if target == b.rbs[i].addr {
				return i
			}
		}
		return -1
	}

	for _, u := range batch {
		target := grpc.Address{Addr: u.Addr, Metadata: u.Metadata}
		switch u.Op {
		case naming.Add:
			// Ignore additions of balancers that are already known.
			if indexOf(target) < 0 {
				b.rbs = append(b.rbs, remoteBalancerInfo{addr: target})
			}
		case naming.Delete:
			if i := indexOf(target); i >= 0 {
				b.rbs = append(b.rbs[:i], b.rbs[i+1:]...)
			}
		default:
			grpclog.Println("Unknown update.Op ", u.Op)
		}
	}

	// TODO: Fall back to the basic round-robin load balancing if the resulting address is
	// not a load balancer.
	if len(b.rbs) == 0 || b.rbs[0] == previous {
		return nil
	}
	// For simplicity, always use the first one now. May revisit this decision later.
	// Discard any value still pending on ch so the send below carries the
	// latest balancer.
	select {
	case <-ch:
	default:
	}
	ch <- b.rbs[0]
	return nil
}
// processServerList converts a server list received from the remote balancer
// into backend addresses, installs it as the current backend set and publishes
// the addresses on b.addrCh.
//
// seq is the sequence number of the balancer stream that produced the list; a
// list from a stream older than b.seq is dropped, as is an empty list or any
// list arriving after the balancer was closed.
//
// NOTE(review): the send on b.addrCh happens while b.mu is held, so it appears
// to block other methods until the Notify() consumer receives — confirm this is
// intended.
func (b *balancer) processServerList(l *lbpb.ServerList, seq int) {
	var (
		backends []addrInfo
		targets  []grpc.Address
	)
	for _, srv := range l.GetServers() {
		// TODO: Support ExpirationInterval
		// TODO: include LoadBalanceToken in the Metadata
		target := grpc.Address{Addr: fmt.Sprintf("%s:%d", srv.IpAddress, srv.Port)}
		// TODO: Support dropRequest feature.
		backends = append(backends, addrInfo{addr: target})
		targets = append(targets, target)
	}

	b.mu.Lock()
	defer b.mu.Unlock()
	switch {
	case b.done, seq < b.seq:
		// Closed, or superseded by a newer balancer stream.
		return
	case len(backends) == 0:
		return
	}
	// reset b.next to 0 when replacing the server list.
	b.next = 0
	b.addrs = backends
	b.addrCh <- targets
}
// callRemoteBalancer runs one BalanceLoad stream against the remote balancer:
// it sends the initial request, validates the initial response, and then feeds
// every server list received to processServerList until the stream fails.
//
// The result tells the caller whether to call again. It is false when the
// stream cannot be created, the balancer is closed, the first reply is not an
// initial response, or delegation is requested; it is true when sending or
// receiving on the stream fails.
func (b *balancer) callRemoteBalancer(lbc lbpb.LoadBalancerClient) (retry bool) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	stream, err := lbc.BalanceLoad(ctx, grpc.FailFast(false))
	if err != nil {
		grpclog.Printf("Failed to perform RPC to the remote balancer %v", err)
		return false
	}

	// Claim a fresh sequence number so server lists from older streams can be
	// recognised and ignored.
	b.mu.Lock()
	if b.done {
		b.mu.Unlock()
		return false
	}
	b.seq++
	streamSeq := b.seq
	b.mu.Unlock()

	hello := &lbpb.LoadBalanceRequest{
		LoadBalanceRequestType: &lbpb.LoadBalanceRequest_InitialRequest{
			InitialRequest: new(lbpb.InitialLoadBalanceRequest),
		},
	}
	// TODO: backoff on retry?
	if sendErr := stream.Send(hello); sendErr != nil {
		return true
	}
	first, err := stream.Recv()
	if err != nil {
		// TODO: backoff on retry?
		return true
	}

	initial := first.GetInitialResponse()
	if initial == nil {
		grpclog.Println("Failed to receive the initial response from the remote balancer.")
		return false
	}
	// TODO: Support delegation.
	if initial.LoadBalancerDelegate != "" {
		// delegation
		grpclog.Println("TODO: Delegation is not supported yet.")
		return false
	}

	// Retrieve the server list.
	for {
		msg, recvErr := stream.Recv()
		if recvErr != nil {
			return true
		}
		list := msg.GetServerList()
		if list == nil {
			continue
		}
		b.processServerList(list, streamSeq)
	}
}
// Start resolves target through the configured resolver and launches two
// background goroutines: one applies naming updates via watchAddrUpdates, the
// other dials whichever remote balancer is currently selected and runs
// callRemoteBalancer against it.
//
// It returns an error if no resolver is installed, grpc.ErrClientConnClosing if
// the balancer is already closed, or the error from Resolve.
func (b *balancer) Start(target string, config grpc.BalancerConfig) error {
	// TODO: Fall back to the basic direct connection if there is no name resolver.
	if b.r == nil {
		return errors.New("there is no name resolver installed")
	}

	// Set up addrCh and the watcher under the lock.
	w, err := func() (naming.Watcher, error) {
		b.mu.Lock()
		defer b.mu.Unlock()
		if b.done {
			return nil, grpc.ErrClientConnClosing
		}
		b.addrCh = make(chan []grpc.Address)
		watcher, resolveErr := b.r.Resolve(target)
		if resolveErr != nil {
			return nil, resolveErr
		}
		b.w = watcher
		return watcher, nil
	}()
	if err != nil {
		return err
	}

	// Capacity 1: watchAddrUpdates drains a stale value before sending.
	lbAddrCh := make(chan remoteBalancerInfo, 1)

	// Spawn a goroutine to monitor the name resolution of remote load balancer.
	go func() {
		for {
			watchErr := b.watchAddrUpdates(w, lbAddrCh)
			if watchErr == nil {
				continue
			}
			grpclog.Printf("grpc: the naming watcher stops working due to %v.\n", watchErr)
			close(lbAddrCh)
			return
		}
	}()

	// Spawn a goroutine to talk to the remote load balancer.
	go func() {
		var cc *grpc.ClientConn
		for rb := range lbAddrCh {
			// A new balancer was selected: drop the connection to the old one.
			if cc != nil {
				cc.Close()
			}
			// Talk to the remote load balancer to get the server list.
			//
			// TODO: override the server name in creds using Metadata in addr.
			var dialErr error
			if creds := config.DialCreds; creds != nil {
				cc, dialErr = grpc.Dial(rb.addr.Addr, grpc.WithTransportCredentials(creds))
			} else {
				cc, dialErr = grpc.Dial(rb.addr.Addr, grpc.WithInsecure())
			}
			if dialErr != nil {
				grpclog.Printf("Failed to setup a connection to the remote balancer %v: %v", rb.addr, dialErr)
				return
			}
			go func(cc *grpc.ClientConn) {
				lbc := lbpb.NewLoadBalancerClient(cc)
				// Keep calling for as long as callRemoteBalancer asks for a retry.
				for b.callRemoteBalancer(lbc) {
				}
				cc.Close()
			}(cc)
		}
		// lbAddrCh was closed: b is closing.
		if cc != nil {
			cc.Close()
		}
	}()
	return nil
}
// down marks addr as disconnected in the current backend list. Up returns a
// closure over it as the callback for when addr goes down. err is the reason
// reported by the caller and is not used.
func (b *balancer) down(addr grpc.Address, err error) {
	b.mu.Lock()
	defer b.mu.Unlock()
	// Index into the slice: b.addrs holds values, so ranging by value would
	// only clear the flag on a copy.
	for i := range b.addrs {
		if b.addrs[i].addr == addr {
			b.addrs[i].connected = false
			break
		}
	}
}

// Up marks addr as connected and, when it is the only connected backend, wakes
// the Get callers blocked waiting for one.
//
// It returns the function to call when addr goes down. It returns nil if the
// balancer is closed or addr is already marked connected.
func (b *balancer) Up(addr grpc.Address) func(error) {
	b.mu.Lock()
	defer b.mu.Unlock()
	if b.done {
		return nil
	}
	var cnt int
	for i := range b.addrs {
		// Take a pointer to the element so the update lands in b.addrs rather
		// than in a loop-local copy.
		a := &b.addrs[i]
		if a.addr == addr {
			if a.connected {
				return nil
			}
			a.connected = true
		}
		if a.connected {
			cnt++
		}
	}
	// addr is the only one which is connected. Notify the Get() callers who are blocking.
	if cnt == 1 && b.waitCh != nil {
		close(b.waitCh)
		b.waitCh = nil
	}
	return func(err error) {
		b.down(addr, err)
	}
}
// Get selects the backend address for an RPC.
//
// Connected backends are handed out round-robin. If none is connected and
// opts.BlockingWait is false, the next address in the list is returned anyway,
// or an error if the list is empty. If opts.BlockingWait is true, Get waits
// until a backend is connected, ctx is done (returning ctx.Err()), or the
// balancer is closed (returning grpc.ErrClientConnClosing). The returned put
// function is always nil.
func (b *balancer) Get(ctx context.Context, opts grpc.BalancerGetOptions) (addr grpc.Address, put func(), err error) {
	// pickConnected makes one full round-robin pass over b.addrs starting at
	// b.next and returns the first connected backend, moving b.next just past
	// it. The caller must hold b.mu.
	pickConnected := func() (grpc.Address, bool) {
		total := len(b.addrs)
		if total == 0 {
			return grpc.Address{}, false
		}
		if b.next >= total {
			b.next = 0
		}
		for step := 0; step < total; step++ {
			idx := (b.next + step) % total
			if b.addrs[idx].connected {
				b.next = (idx + 1) % total
				return b.addrs[idx].addr, true
			}
		}
		// Has iterated all the possible address but none is connected.
		return grpc.Address{}, false
	}

	b.mu.Lock()
	if b.done {
		b.mu.Unlock()
		return grpc.Address{}, nil, grpc.ErrClientConnClosing
	}
	if picked, ok := pickConnected(); ok {
		b.mu.Unlock()
		return picked, nil, nil
	}

	if !opts.BlockingWait {
		if len(b.addrs) == 0 {
			b.mu.Unlock()
			return grpc.Address{}, nil, fmt.Errorf("there is no address available")
		}
		// Returns the next addr on b.addrs for a failfast RPC.
		picked := b.addrs[b.next].addr
		b.next++
		b.mu.Unlock()
		return picked, nil, nil
	}

	// Wait on b.waitCh for non-failfast RPCs.
	wait := b.waitCh
	if wait == nil {
		wait = make(chan struct{})
		b.waitCh = wait
	}
	b.mu.Unlock()

	for {
		select {
		case <-ctx.Done():
			return grpc.Address{}, nil, ctx.Err()
		case <-wait:
		}

		b.mu.Lock()
		if b.done {
			b.mu.Unlock()
			return grpc.Address{}, nil, grpc.ErrClientConnClosing
		}
		if picked, ok := pickConnected(); ok {
			b.mu.Unlock()
			return picked, nil, nil
		}
		// The newly added addr got removed by Down() again.
		wait = b.waitCh
		if wait == nil {
			wait = make(chan struct{})
			b.waitCh = wait
		}
		b.mu.Unlock()
	}
}
// Notify exposes the receive side of b.addrCh, the channel on which
// processServerList sends each backend address list.
func (b *balancer) Notify() <-chan []grpc.Address {
	var updates <-chan []grpc.Address = b.addrCh
	return updates
}
// Close shuts the balancer down: it marks it done, wakes any blocked Get
// callers by closing waitCh, closes the Notify channel and closes the naming
// watcher. It always returns nil.
//
// Close is idempotent; calls after the first are no-ops.
func (b *balancer) Close() error {
	b.mu.Lock()
	defer b.mu.Unlock()
	if b.done {
		// Already closed. Closing waitCh/addrCh a second time would panic.
		return nil
	}
	b.done = true
	if b.waitCh != nil {
		close(b.waitCh)
	}
	if b.addrCh != nil {
		close(b.addrCh)
	}
	if b.w != nil {
		b.w.Close()
	}
	return nil
}