Skip to content

Commit 95c0b96

Browse files
feat: Support passing in shard metadata via data_cli flags
Bug: b/311382140 Change-Id: I0453195f6025d1db611ee17cf2a1b4a7f59813ff GitOrigin-RevId: a1c6676596b87c4428adda0eab1bc1825413f5ce
1 parent 2cad819 commit 95c0b96

7 files changed

+162
-49
lines changed

tools/data_cli/commands/BUILD.bazel

+2
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ cc_library(
3636
"//public/data_loading/readers:delta_record_stream_reader",
3737
"//public/data_loading/writers:delta_record_stream_writer",
3838
"//public/data_loading/writers:delta_record_writer",
39+
"//public/sharding:sharding_function",
3940
"@com_github_google_glog//:glog",
4041
"@com_google_absl//absl/memory",
4142
"@com_google_absl//absl/status",
@@ -71,6 +72,7 @@ cc_library(
7172
"//public/data_loading:riegeli_metadata_cc_proto",
7273
"//public/data_loading/readers:delta_record_stream_reader",
7374
"//public/data_loading/writers:snapshot_stream_writer",
75+
"//public/sharding:sharding_function",
7476
"@com_google_absl//absl/memory",
7577
"@com_google_absl//absl/status",
7678
"@com_google_absl//absl/status:statusor",

tools/data_cli/commands/format_data_command.cc

+33-7
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include "public/data_loading/csv/csv_delta_record_stream_writer.h"
2929
#include "public/data_loading/readers/delta_record_stream_reader.h"
3030
#include "public/data_loading/writers/delta_record_stream_writer.h"
31+
#include "public/sharding/sharding_function.h"
3132
#include "src/cpp/util/status_macro/status_macros.h"
3233

3334
namespace kv_server {
@@ -60,6 +61,14 @@ absl::Status ValidateParams(const FormatDataCommand::Params& params) {
6061
"Input and output format must be different. Input format: ",
6162
params.input_format, " Output format: ", params.output_format));
6263
}
64+
if (params.shard_number >= 0 &&
65+
params.number_of_shards <= params.shard_number) {
66+
return absl::InvalidArgumentError(absl::StrCat(
67+
"Shard metadata is invalid. shard_number is ", params.shard_number,
68+
" and number_of_shards is ", params.number_of_shards,
69+
". Valid inputs must satisfy the requirement: 0 <= shard_number < "
70+
"number_of_shards"));
71+
}
6372
return absl::OkStatus();
6473
}
6574

@@ -128,6 +137,10 @@ absl::StatusOr<std::unique_ptr<DeltaRecordWriter>> CreateRecordWriter(
128137
}
129138
if (lw_output_format == kDeltaFormat) {
130139
KVFileMetadata metadata;
140+
if (params.shard_number >= 0) {
141+
auto* shard_metadata = metadata.mutable_sharding_metadata();
142+
shard_metadata->set_shard_num(params.shard_number);
143+
}
131144
auto delta_record_writer = DeltaRecordStreamWriter<std::ostream>::Create(
132145
output_stream, DeltaRecordWriter::Options{.metadata = metadata});
133146
if (!delta_record_writer.ok()) {
@@ -142,8 +155,7 @@ absl::StatusOr<std::unique_ptr<DeltaRecordWriter>> CreateRecordWriter(
142155
} // namespace
143156

144157
absl::StatusOr<std::unique_ptr<FormatDataCommand>> FormatDataCommand::Create(
145-
const Params& params, std::istream& input_stream,
146-
std::ostream& output_stream) {
158+
Params params, std::istream& input_stream, std::ostream& output_stream) {
147159
if (absl::Status status = ValidateParams(params); !status.ok()) {
148160
return status;
149161
}
@@ -155,21 +167,35 @@ absl::StatusOr<std::unique_ptr<FormatDataCommand>> FormatDataCommand::Create(
155167
if (!record_writer.ok()) {
156168
return record_writer.status();
157169
}
158-
return absl::WrapUnique(new FormatDataCommand(std::move(*record_reader),
159-
std::move(*record_writer)));
170+
return absl::WrapUnique(new FormatDataCommand(
171+
std::move(*record_reader), std::move(*record_writer), params));
160172
}
161173

162174
absl::Status FormatDataCommand::Execute() {
163175
LOG(INFO) << "Formatting records ...";
164176
int64_t records_count = 0;
177+
ShardingFunction sharding_function(/*seed=*/"");
165178
absl::Status status = record_reader_->ReadRecords(
166-
[record_writer = record_writer_.get(),
167-
&records_count](DataRecordStruct data_record) {
179+
[&records_count, &sharding_function, this](DataRecordStruct data_record) {
180+
if (params_.shard_number >= 0 &&
181+
std::holds_alternative<KeyValueMutationRecordStruct>(
182+
data_record.record)) {
183+
KeyValueMutationRecordStruct record_struct =
184+
std::get<KeyValueMutationRecordStruct>(data_record.record);
185+
auto record_shard_num = sharding_function.GetShardNumForKey(
186+
record_struct.key, params_.number_of_shards);
187+
if (params_.shard_number != record_shard_num) {
188+
LOG(INFO) << "Skipping record with key: " << record_struct.key
189+
<< " . The record belongs to shard: " << record_shard_num
190+
<< ", but shard_number is " << params_.shard_number;
191+
return absl::OkStatus();
192+
}
193+
}
168194
records_count++;
169195
if ((double)std::rand() / RAND_MAX <= kSamplingThreshold) {
170196
LOG(INFO) << "Formatting record: " << records_count;
171197
}
172-
return record_writer->WriteRecord(data_record);
198+
return record_writer_->WriteRecord(data_record);
173199
});
174200
record_writer_->Close();
175201
LOG(INFO) << "Sucessfully formated records.";

tools/data_cli/commands/format_data_command.h

+14-10
Original file line numberDiff line numberDiff line change
@@ -55,27 +55,31 @@ namespace kv_server {
5555
class FormatDataCommand : public Command {
5656
public:
5757
struct Params {
58-
std::string_view input_format;
59-
std::string_view output_format;
60-
char csv_column_delimiter;
61-
char csv_value_delimiter;
62-
std::string_view record_type;
63-
std::string_view csv_encoding = "PLAINTEXT";
58+
std::string input_format = "CSV";
59+
std::string output_format = "DELTA";
60+
char csv_column_delimiter = ',';
61+
char csv_value_delimiter = '|';
62+
std::string record_type = "KEY_VALUE_MUTATION_RECORD";
63+
std::string csv_encoding = "PLAINTEXT";
64+
int64_t shard_number = -1;
65+
int64_t number_of_shards = -1;
6466
};
6567

6668
static absl::StatusOr<std::unique_ptr<FormatDataCommand>> Create(
67-
const Params& params, std::istream& input_stream,
68-
std::ostream& output_stream);
69+
Params params, std::istream& input_stream, std::ostream& output_stream);
6970
absl::Status Execute() override;
7071

7172
private:
7273
FormatDataCommand(std::unique_ptr<DeltaRecordReader> record_reader,
73-
std::unique_ptr<DeltaRecordWriter> record_writer)
74+
std::unique_ptr<DeltaRecordWriter> record_writer,
75+
Params params)
7476
: record_reader_(std::move(record_reader)),
75-
record_writer_(std::move(record_writer)) {}
77+
record_writer_(std::move(record_writer)),
78+
params_(std::move(params)) {}
7679

7780
std::unique_ptr<DeltaRecordReader> record_reader_;
7881
std::unique_ptr<DeltaRecordWriter> record_writer_;
82+
Params params_;
7983
};
8084

8185
} // namespace kv_server

tools/data_cli/commands/format_data_command_test.cc

+46-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ FormatDataCommand::Params GetParams(
3636
.output_format = "DELTA",
3737
.csv_column_delimiter = ',',
3838
.csv_value_delimiter = '|',
39-
.record_type = std::move(record_type),
39+
.record_type = std::string(record_type),
4040
};
4141
}
4242

@@ -416,5 +416,50 @@ TEST(FormatDataCommandTest, ValidateIncorrectOutputParams) {
416416
<< status;
417417
}
418418

419+
TEST(FormatDataCommandTest,
420+
ValidateGeneratingCsvToDeltaData_KVMutations_ShardNum) {
421+
std::stringstream csv_stream;
422+
std::stringstream delta_stream;
423+
CsvDeltaRecordStreamWriter csv_writer(csv_stream);
424+
const auto& record = GetDataRecord(GetKVMutationRecord());
425+
EXPECT_TRUE(csv_writer.WriteRecord(record).ok());
426+
EXPECT_TRUE(csv_writer.WriteRecord(record).ok());
427+
EXPECT_TRUE(csv_writer.WriteRecord(record).ok());
428+
csv_writer.Close();
429+
EXPECT_FALSE(csv_stream.str().empty());
430+
auto params = GetParams();
431+
params.shard_number = 2;
432+
params.number_of_shards = 3;
433+
auto command = FormatDataCommand::Create(params, csv_stream, delta_stream);
434+
EXPECT_TRUE(command.ok()) << command.status();
435+
EXPECT_TRUE((*command)->Execute().ok());
436+
DeltaRecordStreamReader delta_reader(delta_stream);
437+
auto metadata = delta_reader.ReadMetadata();
438+
EXPECT_TRUE(metadata.ok());
439+
EXPECT_EQ(metadata->sharding_metadata().shard_num(), 2);
440+
testing::MockFunction<absl::Status(DataRecordStruct)> record_callback;
441+
EXPECT_CALL(record_callback, Call)
442+
.Times(3)
443+
.WillRepeatedly([&record](DataRecordStruct actual_record) {
444+
EXPECT_EQ(actual_record, record);
445+
return absl::OkStatus();
446+
});
447+
EXPECT_TRUE(delta_reader.ReadRecords(record_callback.AsStdFunction()).ok());
448+
}
449+
450+
TEST(FormatDataCommandTest, ValidateIncorrectShardingMetadataParams) {
451+
std::stringstream unused_stream;
452+
auto params = GetParams();
453+
params.shard_number = 2;
454+
absl::Status status =
455+
FormatDataCommand::Create(params, unused_stream, unused_stream).status();
456+
EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument) << status;
457+
EXPECT_STREQ(status.message().data(),
458+
"Shard metadata is invalid. shard_number is 2 and "
459+
"number_of_shards is -1. Valid inputs must satisfy the "
460+
"requirement: 0 <= shard_number < number_of_shards")
461+
<< status;
462+
}
463+
419464
} // namespace
420465
} // namespace kv_server

tools/data_cli/commands/generate_snapshot_command.cc

+43-18
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
#include "public/data_loading/filename_utils.h"
3636
#include "public/data_loading/readers/delta_record_stream_reader.h"
3737
#include "public/data_loading/riegeli_metadata.pb.h"
38+
#include "public/sharding/sharding_function.h"
3839
#include "src/cpp/telemetry/telemetry_provider.h"
3940

4041
namespace kv_server {
@@ -80,6 +81,14 @@ absl::Status ValidateRequiredParams(GenerateSnapshotCommand::Params& params) {
8081
!IsDeltaFilename(params.ending_delta_file)) {
8182
return absl::InvalidArgumentError("Ending delta file is not valid.");
8283
}
84+
if (params.shard_number >= 0 &&
85+
params.number_of_shards <= params.shard_number) {
86+
return absl::InvalidArgumentError(absl::StrCat(
87+
"Shard metadata is invalid. shard_number is ", params.shard_number,
88+
" and number_of_shards is ", params.number_of_shards,
89+
". Valid inputs must satisfy the requirement: 0 <= shard_number < "
90+
"number_of_shards"));
91+
}
8392
return absl::OkStatus();
8493
}
8594

@@ -109,6 +118,10 @@ absl::StatusOr<KVFileMetadata> CreateSnapshotMetadata(
109118
auto snapshot_metadata = metadata.mutable_snapshot();
110119
*snapshot_metadata->mutable_starting_file() = params.starting_file;
111120
*snapshot_metadata->mutable_ending_delta_file() = params.ending_delta_file;
121+
if (params.shard_number >= 0) {
122+
auto* sharding_metadata = metadata.mutable_sharding_metadata();
123+
sharding_metadata->set_shard_num(params.shard_number);
124+
}
112125
return metadata;
113126
}
114127

@@ -117,6 +130,32 @@ void ResetInputStream(std::istream& istream) {
117130
istream.seekg(0, std::ios::beg);
118131
}
119132

133+
absl::Status WriteRecordsToSnapshotStream(
134+
const GenerateSnapshotCommand::Params& params,
135+
DeltaRecordStreamReader<std::istream>& record_reader,
136+
SnapshotStreamWriter<std::ostream>& snapshot_writer) {
137+
ShardingFunction sharding_function(/*seed=*/"");
138+
return record_reader.ReadRecords(
139+
[&params, &snapshot_writer,
140+
&sharding_function](DataRecordStruct data_record) {
141+
if (params.shard_number >= 0 &&
142+
std::holds_alternative<KeyValueMutationRecordStruct>(
143+
data_record.record)) {
144+
KeyValueMutationRecordStruct record_struct =
145+
std::get<KeyValueMutationRecordStruct>(data_record.record);
146+
auto record_shard_num = sharding_function.GetShardNumForKey(
147+
record_struct.key, params.number_of_shards);
148+
if (params.shard_number != record_shard_num) {
149+
LOG(INFO) << "Skipping record with key: " << record_struct.key
150+
<< " . The record belongs to shard: " << record_shard_num
151+
<< ", but shard_number is " << params.shard_number;
152+
return absl::OkStatus();
153+
}
154+
}
155+
return snapshot_writer.WriteRecord(data_record);
156+
});
157+
}
158+
120159
absl::StatusOr<std::string> WriteBaseSnapshotData(
121160
const GenerateSnapshotCommand::Params& params,
122161
BlobStorageClient& blob_client,
@@ -129,13 +168,8 @@ absl::StatusOr<std::string> WriteBaseSnapshotData(
129168
if (!metadata.ok()) {
130169
return metadata.status();
131170
}
132-
if (blob_reader->CanSeek()) {
133-
ResetInputStream(blob_reader->Stream());
134-
} else {
135-
blob_reader = blob_client.GetBlobReader(
136-
{.bucket = params.data_dir.data(), .key = params.starting_file.data()});
137-
}
138-
if (auto status = snapshot_writer.WriteRecordStream(blob_reader->Stream());
171+
if (auto status =
172+
WriteRecordsToSnapshotStream(params, record_reader, snapshot_writer);
139173
!status.ok()) {
140174
return status;
141175
}
@@ -163,17 +197,8 @@ absl::Status WriteDeltaFilesToSnapshot(
163197
auto blob_reader = blob_client.GetBlobReader(
164198
{.bucket = params.data_dir.data(), .key = delta_file});
165199
DeltaRecordStreamReader record_reader(blob_reader->Stream());
166-
auto metadata = record_reader.ReadMetadata();
167-
if (!metadata.ok()) {
168-
return metadata.status();
169-
}
170-
if (blob_reader->CanSeek()) {
171-
ResetInputStream(blob_reader->Stream());
172-
} else {
173-
blob_reader = blob_client.GetBlobReader(
174-
{.bucket = params.data_dir.data(), .key = delta_file});
175-
}
176-
if (auto status = snapshot_writer.WriteRecordStream(blob_reader->Stream());
200+
if (auto status = WriteRecordsToSnapshotStream(params, record_reader,
201+
snapshot_writer);
177202
!status.ok()) {
178203
return status;
179204
}

tools/data_cli/commands/generate_snapshot_command.h

+2
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ class GenerateSnapshotCommand : public Command {
3737
std::string ending_delta_file;
3838
std::string snapshot_file;
3939
bool in_memory_compaction;
40+
int64_t shard_number = -1;
41+
int64_t number_of_shards = -1;
4042
};
4143

4244
~GenerateSnapshotCommand();

0 commit comments

Comments
 (0)