Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions domain_tests/arbitrary_domains_protobuf_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -782,5 +782,53 @@ TEST(ProtocolBuffer, MutationInParallelIsEfficient) {
std::cout << "ratio: " << multi_thread_time / single_thread_time << "\n";
}

TEST(ProtobufDomainTest, CorpusExceedingSoftMaxSizeIsValidWhileHardMaxIsNot) {
auto domain_with_soft_max_size =
Arbitrary<TestProtobuf>().WithRepeatedFieldsSoftMaxSize(100);
auto domain_with_hard_max_size =
Arbitrary<TestProtobuf>().WithRepeatedFieldsMaxSize(100);
TestProtobuf message;
for (int i = 0; i < 200; ++i) {
message.add_rep_i32(i);
}
auto corpus_with_soft_max_size = domain_with_soft_max_size.FromValue(message);
ASSERT_TRUE(corpus_with_soft_max_size.has_value());
EXPECT_TRUE(
domain_with_soft_max_size.ValidateCorpusValue(*corpus_with_soft_max_size)
.ok());

auto corpus_with_hard_max_size = domain_with_hard_max_size.FromValue(message);
ASSERT_TRUE(corpus_with_hard_max_size.has_value());
EXPECT_FALSE(
domain_with_hard_max_size.ValidateCorpusValue(*corpus_with_hard_max_size)
.ok());
}

TEST(ProtobufDomainTest,
NestedCorpusExceedingSoftMaxSizeIsValidAndMutationIsBounded) {
auto domain = Arbitrary<TestProtobuf>().WithRepeatedFieldsSoftMaxSize(100);
TestProtobuf message;
for (int i = 0; i < 200; ++i) {
message.add_rep_i32(i);
}
auto* subproto = message.add_rep_subproto();
for (int i = 0; i < 200; ++i) {
subproto->add_subproto_rep_i32(i);
}
auto corpus = domain.FromValue(message);
ASSERT_TRUE(corpus.has_value());
EXPECT_TRUE(domain.ValidateCorpusValue(*corpus).ok());

absl::BitGen bitgen;
for (int i = 0; i < 1000; ++i) {
domain.Mutate(*corpus, bitgen, {}, /*only_shrink=*/false);
auto mutated_message = domain.GetValue(*corpus);
EXPECT_LE(mutated_message.rep_i32_size(), 200);
if (mutated_message.rep_subproto_size() > 0) {
EXPECT_LE(mutated_message.rep_subproto(0).subproto_rep_i32_size(), 200);
}
}
}

} // namespace
} // namespace fuzztest
21 changes: 17 additions & 4 deletions fuzztest/internal/domains/container_of_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,15 @@ class ContainerOfImplBase
const domain_implementor::MutationMetadata& metadata,
bool only_shrink) {
permanent_dict_candidate_ = std::nullopt;
FUZZTEST_CHECK(min_size() <= val.size() && val.size() <= max_size())
<< "Size " << val.size() << " is not between " << min_size() << " and "
<< max_size();
if (!validate_max_size()) {
FUZZTEST_CHECK(min_size() <= val.size())
<< "Size " << val.size() << " is smaller than min size "
<< min_size();
} else {
FUZZTEST_CHECK(min_size() <= val.size() && val.size() <= max_size())
<< "Size " << val.size() << " is not between " << min_size()
<< " and " << max_size();
}

const bool can_shrink = val.size() > min_size();
const bool can_grow = !only_shrink && val.size() < max_size();
Expand Down Expand Up @@ -250,6 +256,10 @@ class ContainerOfImplBase
manual_dict_provider_ = std::move(manual_dict_provider);
return Self();
}
Derived& WithoutMaxSizeValidation() {
validate_max_size_ = false;
return Self();
}

auto GetPrinter() const {
if constexpr (std::is_same_v<value_type, std::string> ||
Expand Down Expand Up @@ -327,7 +337,7 @@ class ContainerOfImplBase
return absl::InvalidArgumentError(absl::StrCat(
"Invalid size: ", corpus_value.size(), ". Min size: ", min_size()));
}
if (corpus_value.size() > max_size()) {
if (validate_max_size() && corpus_value.size() > max_size()) {
return absl::InvalidArgumentError(absl::StrCat(
"Invalid size: ", corpus_value.size(), ". Max size: ", max_size()));
}
Expand Down Expand Up @@ -355,6 +365,7 @@ class ContainerOfImplBase
OtherInnerDomain>& other) {
min_size_ = other.min_size_;
max_size_ = other.max_size_;
validate_max_size_ = other.validate_max_size_;
}

protected:
Expand All @@ -379,6 +390,7 @@ class ContainerOfImplBase
size_t max_size() const {
return max_size_.value_or(std::max(min_size_, kDefaultContainerMaxSize));
}
bool validate_max_size() const { return validate_max_size_; }

private:
Derived& Self() { return static_cast<Derived&>(*this); }
Expand All @@ -395,6 +407,7 @@ class ContainerOfImplBase
// DO NOT use directly. Use min_size() and max_size() instead.
size_t min_size_ = 0;
std::optional<size_t> max_size_ = std::nullopt;
bool validate_max_size_ = true;

// Temporary memory dictionary. Collected from tracing the program
// execution. It will always be empty if no execution_coverage_ is found,
Expand Down
60 changes: 56 additions & 4 deletions fuzztest/internal/domains/protobuf_domain_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,11 @@ class ProtoPolicy {
{/*filter=*/std::move(filter), /*value=*/max_size});
}

void DisableMaxSizeValidationForRepeatedFields(Filter filter) {
disable_max_size_validation_for_repeated_fields_.push_back(
{/*filter=*/std::move(filter), /*value=*/true});
}

OptionalPolicy GetOptionalPolicy(const FieldDescriptor* field) const {
FUZZTEST_CHECK(!field->is_required() && !field->is_repeated())
<< "GetOptionalPolicy should apply to optional fields only!";
Expand Down Expand Up @@ -318,6 +323,15 @@ class ProtoPolicy {
return max;
}

bool ShouldValidateMaxSize(const FieldDescriptor* repeated_field) const {
FUZZTEST_CHECK(repeated_field->is_repeated())
<< "ShouldValidateMaxSize should apply to repeated "
"fields only!";
return !GetPolicyValue(disable_max_size_validation_for_repeated_fields_,
repeated_field)
.value_or(false);
}

std::optional<bool> IsFieldFinitelyRecursive(const FieldDescriptor* field) {
return caches_->IsFieldFinitelyRecursive(field);
}
Expand Down Expand Up @@ -490,6 +504,8 @@ class ProtoPolicy {
std::vector<FilterToValue<OptionalPolicy>> optional_policies_;
std::vector<FilterToValue<int64_t>> min_repeated_fields_sizes_;
std::vector<FilterToValue<int64_t>> max_repeated_fields_sizes_;
std::vector<FilterToValue<bool>>
disable_max_size_validation_for_repeated_fields_;

#define FUZZTEST_INTERNAL_POLICY_MEMBERS(Camel, cpp) \
private: \
Expand Down Expand Up @@ -918,6 +934,21 @@ class ProtobufDomainUntypedImpl
return std::move(*this);
}

ProtobufDomainUntypedImpl&& WithRepeatedFieldsSoftMaxSize(
int64_t max_size) && {
policy_.SetMaxRepeatedFieldsSize(IncludeAll<FieldDescriptor>(), max_size);
policy_.DisableMaxSizeValidationForRepeatedFields(
IncludeAll<FieldDescriptor>());
return std::move(*this);
}

ProtobufDomainUntypedImpl&& WithRepeatedFieldsSoftMaxSize(
std::function<bool(const FieldDescriptor*)> filter, int64_t max_size) && {
policy_.SetMaxRepeatedFieldsSize(filter, max_size);
policy_.DisableMaxSizeValidationForRepeatedFields(std::move(filter));
return std::move(*this);
}

#define FUZZTEST_INTERNAL_WITH_FIELD(Camel, cpp, TAG) \
using Camel##type = MakeDependentType<cpp, Message>; \
ProtobufDomainUntypedImpl&& With##Camel##Fields( \
Expand Down Expand Up @@ -1626,7 +1657,8 @@ class ProtobufDomainUntypedImpl
return ModifyDomainForRepeatedFieldRule(
std::move(domain),
use_policy ? policy_.GetMinRepeatedFieldSize(field) : std::nullopt,
use_policy ? policy_.GetMaxRepeatedFieldSize(field) : std::nullopt);
use_policy ? policy_.GetMaxRepeatedFieldSize(field) : std::nullopt,
use_policy ? policy_.ShouldValidateMaxSize(field) : true);
} else if (IsRequired(field)) {
return ModifyDomainForRequiredFieldRule(std::move(domain));
} else {
Expand Down Expand Up @@ -1687,16 +1719,20 @@ class ProtobufDomainUntypedImpl

// Simple wrapper that converts a Domain<T> into a Domain<vector<T>>.
template <typename T>
static auto ModifyDomainForRepeatedFieldRule(
const Domain<T>& d, std::optional<int64_t> min_size,
std::optional<int64_t> max_size) {
static auto ModifyDomainForRepeatedFieldRule(const Domain<T>& d,
std::optional<int64_t> min_size,
std::optional<int64_t> max_size,
bool validate_max_size) {
auto result = ContainerOfImpl<std::vector<T>, Domain<T>>(d);
if (min_size.has_value()) {
result.WithMinSize(*min_size);
}
if (max_size.has_value()) {
result.WithMaxSize(*max_size);
}
if (!validate_max_size) {
result.WithoutMaxSizeValidation();
}
return result;
}

Expand Down Expand Up @@ -2159,6 +2195,22 @@ class ProtobufDomainImpl
return std::move(*this);
}

ProtobufDomainImpl&& WithRepeatedFieldsSoftMaxSize(int64_t max_size) && {
inner_.GetPolicy().SetMaxRepeatedFieldsSize(IncludeAll<FieldDescriptor>(),
max_size);
inner_.GetPolicy().DisableMaxSizeValidationForRepeatedFields(
IncludeAll<FieldDescriptor>());
return std::move(*this);
}

ProtobufDomainImpl&& WithRepeatedFieldsSoftMaxSize(
std::function<bool(const FieldDescriptor*)> filter, int64_t max_size) && {
inner_.GetPolicy().SetMaxRepeatedFieldsSize(filter, max_size);
inner_.GetPolicy().DisableMaxSizeValidationForRepeatedFields(
std::move(filter));
return std::move(*this);
}

ProtobufDomainImpl&& WithFieldUnset(absl::string_view field) && {
inner_.WithFieldNullness(field, OptionalPolicy::kAlwaysNull);
return std::move(*this);
Expand Down
Loading