Skip to content

Commit 594732d

Browse files
authored
fix(spanner): avoid desructive context augmentation that dropped all headers (#11659)
* fix(spanner): avoid desructive context augmentation that dropped all headers This change fixes a destructive context operation that dropped all prior headers, metadata.NewOutgoingContext sadly doesn't document that it doesn't reuse the input conext negating the expectation in Go that building from a parent context brings along the prior keys and metadata. While here, this chhange adds a regression test to ensure that in the future, any dropped or lost metadata will be reported during development. Fixes #11656 * Address review feedback+nits
1 parent 3866f33 commit 594732d

File tree

3 files changed

+216
-4
lines changed

3 files changed

+216
-4
lines changed

spanner/client_test.go

+2
Original file line numberDiff line numberDiff line change
@@ -4482,6 +4482,8 @@ func TestClient_WithCustomBatchTimeout(t *testing.T) {
44824482
}
44834483
}
44844484

4485+
var makeMockServer = NewMockedSpannerInMemTestServer
4486+
44854487
func TestClient_WithoutCustomBatchTimeout(t *testing.T) {
44864488
t.Parallel()
44874489

spanner/regression_test.go

+210
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package spanner
16+
17+
import (
18+
"context"
19+
"errors"
20+
"fmt"
21+
"maps"
22+
"slices"
23+
"sort"
24+
"testing"
25+
26+
sppb "cloud.google.com/go/spanner/apiv1/spannerpb"
27+
"github.com/google/go-cmp/cmp"
28+
"google.golang.org/grpc"
29+
"google.golang.org/grpc/metadata"
30+
"google.golang.org/protobuf/types/known/structpb"
31+
32+
"cloud.google.com/go/spanner/internal/testutil"
33+
)
34+
35+
type methodAndMetadata struct {
36+
method string
37+
md metadata.MD
38+
}
39+
40+
type ourInterceptor struct {
41+
unaryHeaders []*methodAndMetadata
42+
streamHeaders []*methodAndMetadata
43+
}
44+
45+
func (oi *ourInterceptor) interceptStream(srv any, ss grpc.ServerStream, ssi *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
46+
md, ok := metadata.FromIncomingContext(ss.Context())
47+
if !ok {
48+
return errors.New("missing metadata in stream")
49+
}
50+
oi.streamHeaders = append(oi.streamHeaders, &methodAndMetadata{ssi.FullMethod, md})
51+
return handler(srv, ss)
52+
}
53+
54+
func (oi *ourInterceptor) interceptUnary(ctx context.Context, req any, usi *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
55+
md, ok := metadata.FromIncomingContext(ctx)
56+
if !ok {
57+
return nil, errors.New("missing metadata in unary")
58+
}
59+
oi.unaryHeaders = append(oi.unaryHeaders, &methodAndMetadata{usi.FullMethod, md})
60+
return handler(ctx, req)
61+
}
62+
63+
// This is a regression test to assert that all the expected headers are propagated
64+
// along to the final gRPC server avoiding scenarios where headers got dropped from a
65+
// destructive context augmentation call.
66+
// Please see https://github.com/googleapis/google-cloud-go/issues/11656
67+
func TestAllHeadersForwardedAppropriately(t *testing.T) {
68+
// 0. Turn off session multiplexing per #11308.
69+
t.Setenv("GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS", "0")
70+
71+
// 1. Set up the server interceptor that'll record and collect
72+
// all the headers that are received by the server.
73+
oint := new(ourInterceptor)
74+
sopts := []grpc.ServerOption{
75+
grpc.UnaryInterceptor(oint.interceptUnary), grpc.StreamInterceptor(oint.interceptStream),
76+
}
77+
mockedServer, clientOpts, teardown := makeMockServer(t, sopts...)
78+
defer teardown()
79+
80+
clientConfig := ClientConfig{
81+
SessionPoolConfig: SessionPoolConfig{
82+
MinOpened: 2,
83+
MaxOpened: 10,
84+
WriteSessions: 0.2,
85+
incStep: 2,
86+
},
87+
EnableEndToEndTracing: true,
88+
DisableRouteToLeader: false,
89+
}
90+
formattedDatabase := fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]")
91+
sc, err := NewClientWithConfig(context.Background(), formattedDatabase, clientConfig, clientOpts...)
92+
if err != nil {
93+
t.Fatal(err)
94+
}
95+
defer sc.Close()
96+
97+
// 2. Perform a simple "SELECT 1" to trigger both unary and streaming gRPC calls.
98+
sqlSELECT1 := "SELECT 1"
99+
resultSet := &sppb.ResultSet{
100+
Rows: []*structpb.ListValue{
101+
{Values: []*structpb.Value{
102+
{Kind: &structpb.Value_StringValue{StringValue: "1"}},
103+
}},
104+
},
105+
Metadata: &sppb.ResultSetMetadata{
106+
RowType: &sppb.StructType{
107+
Fields: []*sppb.StructType_Field{
108+
{Name: "Int", Type: &sppb.Type{Code: sppb.TypeCode_INT64}},
109+
},
110+
},
111+
},
112+
}
113+
result := &testutil.StatementResult{
114+
Type: testutil.StatementResultResultSet,
115+
ResultSet: resultSet,
116+
}
117+
mockedServer.TestSpanner.PutStatementResult(sqlSELECT1, result)
118+
119+
txn := sc.ReadOnlyTransaction()
120+
defer txn.Close()
121+
122+
ctx := context.Background()
123+
stmt := NewStatement(sqlSELECT1)
124+
rowIter := txn.Query(ctx, stmt)
125+
defer rowIter.Stop()
126+
var got []int64
127+
if err := SelectAll(rowIter, &got); err != nil {
128+
t.Fatal(err)
129+
}
130+
want := []int64{1}
131+
if diff := cmp.Diff(got, want); diff != "" {
132+
t.Fatalf("Results expectation mismatches: got - want +\n%s", diff)
133+
}
134+
135+
// 3. Now perform the assertions of expected headers.
136+
type headerExpectation struct {
137+
MethodName string
138+
WantHeaders []string
139+
}
140+
141+
wantUnaryExpectations := []*headerExpectation{
142+
{
143+
"/google.spanner.v1.Spanner/BatchCreateSessions",
144+
[]string{
145+
":authority", "content-type", "google-cloud-resource-prefix",
146+
"grpc-accept-encoding", "user-agent", "x-goog-api-client",
147+
"x-goog-request-params", "x-goog-spanner-end-to-end-tracing",
148+
"x-goog-spanner-request-id", "x-goog-spanner-route-to-leader",
149+
},
150+
},
151+
{
152+
"/google.spanner.v1.Spanner/BeginTransaction",
153+
[]string{
154+
":authority", "content-type", "google-cloud-resource-prefix",
155+
"grpc-accept-encoding", "user-agent", "x-goog-api-client",
156+
"x-goog-request-params", "x-goog-spanner-end-to-end-tracing",
157+
"x-goog-spanner-request-id",
158+
},
159+
},
160+
}
161+
162+
wantStreamingExpectations := []*headerExpectation{
163+
{
164+
"/google.spanner.v1.Spanner/ExecuteStreamingSql",
165+
[]string{
166+
":authority", "content-type", "google-cloud-resource-prefix",
167+
"grpc-accept-encoding", "user-agent", "x-goog-api-client",
168+
"x-goog-request-params", "x-goog-spanner-end-to-end-tracing",
169+
"x-goog-spanner-request-id",
170+
},
171+
},
172+
}
173+
174+
var gotUnaryExpectations []*headerExpectation
175+
for _, mdp := range oint.unaryHeaders {
176+
gotHeaderKeys := slices.Collect(maps.Keys(mdp.md))
177+
gotUnaryExpectations = append(gotUnaryExpectations, &headerExpectation{mdp.method, gotHeaderKeys})
178+
}
179+
180+
var gotStreamingExpectations []*headerExpectation
181+
for _, mdp := range oint.streamHeaders {
182+
gotHeaderKeys := slices.Collect(maps.Keys(mdp.md))
183+
gotStreamingExpectations = append(gotStreamingExpectations, &headerExpectation{mdp.method, gotHeaderKeys})
184+
}
185+
186+
sortHeaderExpectations := func(expectations []*headerExpectation) {
187+
// Firstly sort by method name.
188+
sort.Slice(expectations, func(i, j int) bool {
189+
return expectations[i].MethodName < expectations[j].MethodName
190+
})
191+
192+
// 2. Within each expectation, also then sort the header keys.
193+
for i := range expectations {
194+
exp := expectations[i]
195+
sort.Strings(exp.WantHeaders)
196+
}
197+
}
198+
199+
sortHeaderExpectations(gotUnaryExpectations)
200+
sortHeaderExpectations(wantUnaryExpectations)
201+
if diff := cmp.Diff(gotUnaryExpectations, wantUnaryExpectations); diff != "" {
202+
t.Fatalf("Unary headers mismatch: got - want +\n%s", diff)
203+
}
204+
205+
sortHeaderExpectations(gotStreamingExpectations)
206+
sortHeaderExpectations(wantStreamingExpectations)
207+
if diff := cmp.Diff(gotStreamingExpectations, wantStreamingExpectations); diff != "" {
208+
t.Fatalf("Streaming headers mismatch: got - want +\n%s", diff)
209+
}
210+
}

spanner/request_id_header.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -138,9 +138,9 @@ func gRPCCallOptionsToRequestID(opts []grpc.CallOption) (md metadata.MD, reqID r
138138
func (wr *requestIDHeaderInjector) interceptUnary(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
139139
// It is imperative to search for the requestID before the call
140140
// because gRPC's internals will consume the headers.
141-
metadataWithRequestID, reqID, foundRequestID := gRPCCallOptionsToRequestID(opts)
141+
_, reqID, foundRequestID := gRPCCallOptionsToRequestID(opts)
142142
if foundRequestID {
143-
ctx = metadata.NewOutgoingContext(ctx, metadataWithRequestID)
143+
ctx = metadata.AppendToOutgoingContext(ctx, xSpannerRequestIDHeader, string(reqID))
144144
}
145145

146146
err := invoker(ctx, method, req, reply, cc, opts...)
@@ -179,9 +179,9 @@ type requestIDHeaderInjector int
179179
func (wr *requestIDHeaderInjector) interceptStream(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
180180
// It is imperative to search for the requestID before the call
181181
// because gRPC's internals will consume the headers.
182-
metadataWithRequestID, reqID, foundRequestID := gRPCCallOptionsToRequestID(opts)
182+
_, reqID, foundRequestID := gRPCCallOptionsToRequestID(opts)
183183
if foundRequestID {
184-
ctx = metadata.NewOutgoingContext(ctx, metadataWithRequestID)
184+
ctx = metadata.AppendToOutgoingContext(ctx, xSpannerRequestIDHeader, string(reqID))
185185
}
186186

187187
cs, err := streamer(ctx, desc, cc, method, opts...)

0 commit comments

Comments
 (0)