forked from mystiq/dex
451 lines
10 KiB
Go
451 lines
10 KiB
Go
|
/*
 *
 * Copyright 2016, Google Inc.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are
 * met:
 *
 *     * Redistributions of source code must retain the above copyright
 * notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above
 * copyright notice, this list of conditions and the following disclaimer
 * in the documentation and/or other materials provided with the
 * distribution.
 *     * Neither the name of Google Inc. nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
 * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 */
|
||
|
|
||
|
// Package grpclb implements the load balancing protocol defined at
// https://github.com/grpc/grpc/blob/master/doc/load-balancing.md.
// The implementation is currently EXPERIMENTAL.
|
||
|
package grpclb
|
||
|
|
||
|
import (
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"sync"
|
||
|
|
||
|
"golang.org/x/net/context"
|
||
|
"google.golang.org/grpc"
|
||
|
lbpb "google.golang.org/grpc/grpclb/grpc_lb_v1"
|
||
|
"google.golang.org/grpc/grpclog"
|
||
|
"google.golang.org/grpc/naming"
|
||
|
)
|
||
|
|
||
|
// Balancer creates a grpclb load balancer.
|
||
|
func Balancer(r naming.Resolver) grpc.Balancer {
|
||
|
return &balancer{
|
||
|
r: r,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
type remoteBalancerInfo struct {
|
||
|
addr grpc.Address
|
||
|
name string
|
||
|
}
|
||
|
|
||
|
// addrInfo consists of the information of a backend server.
|
||
|
type addrInfo struct {
|
||
|
addr grpc.Address
|
||
|
connected bool
|
||
|
dropRequest bool
|
||
|
}
|
||
|
|
||
|
type balancer struct {
|
||
|
r naming.Resolver
|
||
|
mu sync.Mutex
|
||
|
seq int // a sequence number to make sure addrCh does not get stale addresses.
|
||
|
w naming.Watcher
|
||
|
addrCh chan []grpc.Address
|
||
|
rbs []remoteBalancerInfo
|
||
|
addrs []addrInfo
|
||
|
next int
|
||
|
waitCh chan struct{}
|
||
|
done bool
|
||
|
}
|
||
|
|
||
|
func (b *balancer) watchAddrUpdates(w naming.Watcher, ch chan remoteBalancerInfo) error {
|
||
|
updates, err := w.Next()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
b.mu.Lock()
|
||
|
defer b.mu.Unlock()
|
||
|
if b.done {
|
||
|
return grpc.ErrClientConnClosing
|
||
|
}
|
||
|
var bAddr remoteBalancerInfo
|
||
|
if len(b.rbs) > 0 {
|
||
|
bAddr = b.rbs[0]
|
||
|
}
|
||
|
for _, update := range updates {
|
||
|
addr := grpc.Address{
|
||
|
Addr: update.Addr,
|
||
|
Metadata: update.Metadata,
|
||
|
}
|
||
|
switch update.Op {
|
||
|
case naming.Add:
|
||
|
var exist bool
|
||
|
for _, v := range b.rbs {
|
||
|
// TODO: Is the same addr with different server name a different balancer?
|
||
|
if addr == v.addr {
|
||
|
exist = true
|
||
|
break
|
||
|
}
|
||
|
}
|
||
|
if exist {
|
||
|
continue
|
||
|
}
|
||
|
b.rbs = append(b.rbs, remoteBalancerInfo{addr: addr})
|
||
|
case naming.Delete:
|
||
|
for i, v := range b.rbs {
|
||
|
if addr == v.addr {
|
||
|
copy(b.rbs[i:], b.rbs[i+1:])
|
||
|
b.rbs = b.rbs[:len(b.rbs)-1]
|
||
|
break
|
||
|
}
|
||
|
}
|
||
|
default:
|
||
|
grpclog.Println("Unknown update.Op ", update.Op)
|
||
|
}
|
||
|
}
|
||
|
// TODO: Fall back to the basic round-robin load balancing if the resulting address is
|
||
|
// not a load balancer.
|
||
|
if len(b.rbs) > 0 {
|
||
|
// For simplicity, always use the first one now. May revisit this decision later.
|
||
|
if b.rbs[0] != bAddr {
|
||
|
select {
|
||
|
case <-ch:
|
||
|
default:
|
||
|
}
|
||
|
ch <- b.rbs[0]
|
||
|
}
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (b *balancer) processServerList(l *lbpb.ServerList, seq int) {
|
||
|
servers := l.GetServers()
|
||
|
var (
|
||
|
sl []addrInfo
|
||
|
addrs []grpc.Address
|
||
|
)
|
||
|
for _, s := range servers {
|
||
|
// TODO: Support ExpirationInterval
|
||
|
addr := grpc.Address{
|
||
|
Addr: fmt.Sprintf("%s:%d", s.IpAddress, s.Port),
|
||
|
// TODO: include LoadBalanceToken in the Metadata
|
||
|
}
|
||
|
sl = append(sl, addrInfo{
|
||
|
addr: addr,
|
||
|
// TODO: Support dropRequest feature.
|
||
|
})
|
||
|
addrs = append(addrs, addr)
|
||
|
}
|
||
|
b.mu.Lock()
|
||
|
defer b.mu.Unlock()
|
||
|
if b.done || seq < b.seq {
|
||
|
return
|
||
|
}
|
||
|
if len(sl) > 0 {
|
||
|
// reset b.next to 0 when replacing the server list.
|
||
|
b.next = 0
|
||
|
b.addrs = sl
|
||
|
b.addrCh <- addrs
|
||
|
}
|
||
|
return
|
||
|
}
|
||
|
|
||
|
func (b *balancer) callRemoteBalancer(lbc lbpb.LoadBalancerClient) (retry bool) {
|
||
|
ctx, cancel := context.WithCancel(context.Background())
|
||
|
defer cancel()
|
||
|
stream, err := lbc.BalanceLoad(ctx, grpc.FailFast(false))
|
||
|
if err != nil {
|
||
|
grpclog.Printf("Failed to perform RPC to the remote balancer %v", err)
|
||
|
return
|
||
|
}
|
||
|
b.mu.Lock()
|
||
|
if b.done {
|
||
|
b.mu.Unlock()
|
||
|
return
|
||
|
}
|
||
|
b.seq++
|
||
|
seq := b.seq
|
||
|
b.mu.Unlock()
|
||
|
initReq := &lbpb.LoadBalanceRequest{
|
||
|
LoadBalanceRequestType: &lbpb.LoadBalanceRequest_InitialRequest{
|
||
|
InitialRequest: new(lbpb.InitialLoadBalanceRequest),
|
||
|
},
|
||
|
}
|
||
|
if err := stream.Send(initReq); err != nil {
|
||
|
// TODO: backoff on retry?
|
||
|
return true
|
||
|
}
|
||
|
reply, err := stream.Recv()
|
||
|
if err != nil {
|
||
|
// TODO: backoff on retry?
|
||
|
return true
|
||
|
}
|
||
|
initResp := reply.GetInitialResponse()
|
||
|
if initResp == nil {
|
||
|
grpclog.Println("Failed to receive the initial response from the remote balancer.")
|
||
|
return
|
||
|
}
|
||
|
// TODO: Support delegation.
|
||
|
if initResp.LoadBalancerDelegate != "" {
|
||
|
// delegation
|
||
|
grpclog.Println("TODO: Delegation is not supported yet.")
|
||
|
return
|
||
|
}
|
||
|
// Retrieve the server list.
|
||
|
for {
|
||
|
reply, err := stream.Recv()
|
||
|
if err != nil {
|
||
|
break
|
||
|
}
|
||
|
if serverList := reply.GetServerList(); serverList != nil {
|
||
|
b.processServerList(serverList, seq)
|
||
|
}
|
||
|
}
|
||
|
return true
|
||
|
}
|
||
|
|
||
|
func (b *balancer) Start(target string, config grpc.BalancerConfig) error {
|
||
|
// TODO: Fall back to the basic direct connection if there is no name resolver.
|
||
|
if b.r == nil {
|
||
|
return errors.New("there is no name resolver installed")
|
||
|
}
|
||
|
b.mu.Lock()
|
||
|
if b.done {
|
||
|
b.mu.Unlock()
|
||
|
return grpc.ErrClientConnClosing
|
||
|
}
|
||
|
b.addrCh = make(chan []grpc.Address)
|
||
|
w, err := b.r.Resolve(target)
|
||
|
if err != nil {
|
||
|
b.mu.Unlock()
|
||
|
return err
|
||
|
}
|
||
|
b.w = w
|
||
|
b.mu.Unlock()
|
||
|
balancerAddrCh := make(chan remoteBalancerInfo, 1)
|
||
|
// Spawn a goroutine to monitor the name resolution of remote load balancer.
|
||
|
go func() {
|
||
|
for {
|
||
|
if err := b.watchAddrUpdates(w, balancerAddrCh); err != nil {
|
||
|
grpclog.Printf("grpc: the naming watcher stops working due to %v.\n", err)
|
||
|
close(balancerAddrCh)
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
}()
|
||
|
// Spawn a goroutine to talk to the remote load balancer.
|
||
|
go func() {
|
||
|
var cc *grpc.ClientConn
|
||
|
for {
|
||
|
rb, ok := <-balancerAddrCh
|
||
|
if cc != nil {
|
||
|
cc.Close()
|
||
|
}
|
||
|
if !ok {
|
||
|
// b is closing.
|
||
|
return
|
||
|
}
|
||
|
|
||
|
// Talk to the remote load balancer to get the server list.
|
||
|
//
|
||
|
// TODO: override the server name in creds using Metadata in addr.
|
||
|
var err error
|
||
|
creds := config.DialCreds
|
||
|
if creds == nil {
|
||
|
cc, err = grpc.Dial(rb.addr.Addr, grpc.WithInsecure())
|
||
|
} else {
|
||
|
cc, err = grpc.Dial(rb.addr.Addr, grpc.WithTransportCredentials(creds))
|
||
|
}
|
||
|
if err != nil {
|
||
|
grpclog.Printf("Failed to setup a connection to the remote balancer %v: %v", rb.addr, err)
|
||
|
return
|
||
|
}
|
||
|
go func(cc *grpc.ClientConn) {
|
||
|
lbc := lbpb.NewLoadBalancerClient(cc)
|
||
|
for {
|
||
|
if retry := b.callRemoteBalancer(lbc); !retry {
|
||
|
cc.Close()
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
}(cc)
|
||
|
}
|
||
|
}()
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (b *balancer) down(addr grpc.Address, err error) {
|
||
|
b.mu.Lock()
|
||
|
defer b.mu.Unlock()
|
||
|
for _, a := range b.addrs {
|
||
|
if addr == a.addr {
|
||
|
a.connected = false
|
||
|
break
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (b *balancer) Up(addr grpc.Address) func(error) {
|
||
|
b.mu.Lock()
|
||
|
defer b.mu.Unlock()
|
||
|
if b.done {
|
||
|
return nil
|
||
|
}
|
||
|
var cnt int
|
||
|
for _, a := range b.addrs {
|
||
|
if a.addr == addr {
|
||
|
if a.connected {
|
||
|
return nil
|
||
|
}
|
||
|
a.connected = true
|
||
|
}
|
||
|
if a.connected {
|
||
|
cnt++
|
||
|
}
|
||
|
}
|
||
|
// addr is the only one which is connected. Notify the Get() callers who are blocking.
|
||
|
if cnt == 1 && b.waitCh != nil {
|
||
|
close(b.waitCh)
|
||
|
b.waitCh = nil
|
||
|
}
|
||
|
return func(err error) {
|
||
|
b.down(addr, err)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (b *balancer) Get(ctx context.Context, opts grpc.BalancerGetOptions) (addr grpc.Address, put func(), err error) {
|
||
|
var ch chan struct{}
|
||
|
b.mu.Lock()
|
||
|
if b.done {
|
||
|
b.mu.Unlock()
|
||
|
err = grpc.ErrClientConnClosing
|
||
|
return
|
||
|
}
|
||
|
|
||
|
if len(b.addrs) > 0 {
|
||
|
if b.next >= len(b.addrs) {
|
||
|
b.next = 0
|
||
|
}
|
||
|
next := b.next
|
||
|
for {
|
||
|
a := b.addrs[next]
|
||
|
next = (next + 1) % len(b.addrs)
|
||
|
if a.connected {
|
||
|
addr = a.addr
|
||
|
b.next = next
|
||
|
b.mu.Unlock()
|
||
|
return
|
||
|
}
|
||
|
if next == b.next {
|
||
|
// Has iterated all the possible address but none is connected.
|
||
|
break
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
if !opts.BlockingWait {
|
||
|
if len(b.addrs) == 0 {
|
||
|
b.mu.Unlock()
|
||
|
err = fmt.Errorf("there is no address available")
|
||
|
return
|
||
|
}
|
||
|
// Returns the next addr on b.addrs for a failfast RPC.
|
||
|
addr = b.addrs[b.next].addr
|
||
|
b.next++
|
||
|
b.mu.Unlock()
|
||
|
return
|
||
|
}
|
||
|
// Wait on b.waitCh for non-failfast RPCs.
|
||
|
if b.waitCh == nil {
|
||
|
ch = make(chan struct{})
|
||
|
b.waitCh = ch
|
||
|
} else {
|
||
|
ch = b.waitCh
|
||
|
}
|
||
|
b.mu.Unlock()
|
||
|
for {
|
||
|
select {
|
||
|
case <-ctx.Done():
|
||
|
err = ctx.Err()
|
||
|
return
|
||
|
case <-ch:
|
||
|
b.mu.Lock()
|
||
|
if b.done {
|
||
|
b.mu.Unlock()
|
||
|
err = grpc.ErrClientConnClosing
|
||
|
return
|
||
|
}
|
||
|
|
||
|
if len(b.addrs) > 0 {
|
||
|
if b.next >= len(b.addrs) {
|
||
|
b.next = 0
|
||
|
}
|
||
|
next := b.next
|
||
|
for {
|
||
|
a := b.addrs[next]
|
||
|
next = (next + 1) % len(b.addrs)
|
||
|
if a.connected {
|
||
|
addr = a.addr
|
||
|
b.next = next
|
||
|
b.mu.Unlock()
|
||
|
return
|
||
|
}
|
||
|
if next == b.next {
|
||
|
// Has iterated all the possible address but none is connected.
|
||
|
break
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
// The newly added addr got removed by Down() again.
|
||
|
if b.waitCh == nil {
|
||
|
ch = make(chan struct{})
|
||
|
b.waitCh = ch
|
||
|
} else {
|
||
|
ch = b.waitCh
|
||
|
}
|
||
|
b.mu.Unlock()
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (b *balancer) Notify() <-chan []grpc.Address {
|
||
|
return b.addrCh
|
||
|
}
|
||
|
|
||
|
func (b *balancer) Close() error {
|
||
|
b.mu.Lock()
|
||
|
defer b.mu.Unlock()
|
||
|
b.done = true
|
||
|
if b.waitCh != nil {
|
||
|
close(b.waitCh)
|
||
|
}
|
||
|
if b.addrCh != nil {
|
||
|
close(b.addrCh)
|
||
|
}
|
||
|
if b.w != nil {
|
||
|
b.w.Close()
|
||
|
}
|
||
|
return nil
|
||
|
}
|