diff --git a/xls/scheduling/BUILD b/xls/scheduling/BUILD index 864e46c299..99ac9391db 100644 --- a/xls/scheduling/BUILD +++ b/xls/scheduling/BUILD @@ -460,6 +460,7 @@ cc_library( "//xls/ir:value", "//xls/passes:pass_base", "//xls/solvers:z3_ir_translator", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", @@ -488,7 +489,9 @@ cc_test( "//xls/ir:function_builder", "//xls/ir:ir_matcher", "//xls/ir:ir_test_base", + "//xls/ir:op", "//xls/ir:proc_conversion", + "//xls/ir:source_location", "//xls/ir:value", "//xls/passes:cse_pass", "//xls/passes:dce_pass", @@ -497,6 +500,7 @@ cc_test( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", "@googletest//:gtest", ], ) diff --git a/xls/scheduling/proc_state_legalization_pass.cc b/xls/scheduling/proc_state_legalization_pass.cc index 857249f155..e39c1929f8 100644 --- a/xls/scheduling/proc_state_legalization_pass.cc +++ b/xls/scheduling/proc_state_legalization_pass.cc @@ -20,6 +20,7 @@ #include #include +#include "absl/algorithm/container.h" #include "absl/container/btree_set.h" #include "absl/container/flat_hash_set.h" #include "absl/log/check.h" @@ -53,9 +54,19 @@ namespace { absl::StatusOr LegalizeStateReadPredicate( Proc* proc, StateElement* state_element, const SchedulingPassOptions& options) { - StateRead* state_read = proc->GetStateReadByStateElement(state_element); + absl::Span state_reads = + proc->GetStateReadsByStateElement(state_element); const absl::btree_set& next_values = proc->next_values(state_element); + + // For now only legalize predicate if there is exactly one state_read to + // ensure a simple read and write pass. Later on we can expand this to + // handle multiple state_reads. + if (state_reads.size() != 1) { + return false; + } + + StateRead* state_read = state_reads.front(); if (!state_read->predicate().has_value() || next_values.empty()) { // Already unconditional, or no explicit `next_value`s; nothing to do. return false; @@ -66,7 +77,7 @@ absl::StatusOr LegalizeStateReadPredicate( predicates.reserve(1 + next_values.size()); predicates_set.reserve(next_values.size()); for (Next* next : next_values) { - if (next->state_read() == next->value()) { + if (state_read == next->value()) { // This is a no-op next_value; we will narrow it to the case where the // state read is active instead. continue; @@ -133,7 +144,7 @@ absl::StatusOr LegalizeStateReadPredicates( return changed; } -absl::StatusOr AddMutualExclusionAssert( +absl::StatusOr AddWriteMutualExclusionAssert( Proc* proc, StateElement* state_element, const SchedulingPassOptions& options) { const absl::btree_set& next_values = @@ -195,6 +206,87 @@ absl::StatusOr AddMutualExclusionAssert( return true; } +absl::StatusOr AddReadMutualExclusionAssert( + Proc* proc, StateElement* state_element, + const SchedulingPassOptions& options) { + absl::Span state_reads = + proc->GetStateReadsByStateElement(state_element); + if (state_reads.size() < 2) { + return false; + } + std::string label = + absl::StrCat("__", state_element->name(), "__at_most_one_read_assert"); + if (proc->HasNode(label)) { + return absl::InternalError(absl::StrFormat( + "Read mutual exclusion assert already exists for state " + "element '%s'; was this pass run twice? assert label: %s", + state_element->name(), label)); + } + std::vector predicate_list; + Node* true_lit = nullptr; + for (StateRead* state_read : state_reads) { + if (state_read->predicate().has_value()) { + predicate_list.push_back(*state_read->predicate()); + } else { + if (true_lit == nullptr) { + XLS_ASSIGN_OR_RETURN( + true_lit, proc->MakeNode(SourceInfo(), Value::Bool(true))); + } + predicate_list.push_back(true_lit); + } + } + XLS_ASSIGN_OR_RETURN( + Node * predicates, + proc->MakeNodeWithName( + SourceInfo(), predicate_list, + absl::StrCat("__", state_element->name(), "__read_predicates"))); + XLS_ASSIGN_OR_RETURN( + Node * one_hot_predicates, + proc->MakeNode(SourceInfo(), predicates, LsbOrMsb::kLsb)); + XLS_ASSIGN_OR_RETURN(Node * at_most_one_predicate, + proc->MakeNode( + SourceInfo(), one_hot_predicates, /*start=*/0, + /*width=*/one_hot_predicates->BitCountOrDie() - 1)); + XLS_ASSIGN_OR_RETURN( + Node * at_most_one_read, + proc->MakeNodeWithName( + SourceInfo(), predicates, at_most_one_predicate, Op::kEq, + absl::StrCat("__", state_element->name(), "__at_most_one_read"))); + XLS_ASSIGN_OR_RETURN(Node * tkn, + proc->MakeNode(SourceInfo(), Value::Token())); + XLS_RETURN_IF_ERROR( + proc->MakeNodeWithName( + SourceInfo(), tkn, + /*condition=*/at_most_one_read, + /*message=*/ + absl::StrCat("More than one StateRead active for state element: ", + state_element->name()), + /*label=*/label, + /*original_label=*/std::nullopt, + /*name=*/label) + .status()); + return true; +} + +absl::StatusOr AddMutualExclusionAssert( + Proc* proc, StateElement* state_element, + const SchedulingPassOptions& options) { + bool changed = false; + XLS_ASSIGN_OR_RETURN( + bool write_assert_added, + AddWriteMutualExclusionAssert(proc, state_element, options)); + if (write_assert_added) { + changed = true; + } + XLS_ASSIGN_OR_RETURN( + bool read_assert_added, + AddReadMutualExclusionAssert(proc, state_element, options)); + if (read_assert_added) { + changed = true; + } + return changed; +} + absl::StatusOr AddMutualExclusionAsserts( Proc* proc, const SchedulingPassOptions& options) { bool changed = false; @@ -215,8 +307,12 @@ absl::StatusOr AddMutualExclusionAsserts( absl::StatusOr AddWriteWithoutReadAsserts( Proc* proc, StateElement* state_element, const SchedulingPassOptions& options) { - StateRead* state_read = proc->GetStateReadByStateElement(state_element); - if (!state_read->predicate().has_value()) { + absl::Span state_reads = + proc->GetStateReadsByStateElement(state_element); + + if (absl::c_any_of(state_reads, [](StateRead* state_read) { + return !state_read->predicate().has_value(); + })) { return false; } @@ -226,7 +322,23 @@ absl::StatusOr AddWriteWithoutReadAsserts( return false; } - std::vector predicate_list; + std::vector read_predicates; + read_predicates.reserve(state_reads.size()); + for (StateRead* state_read : state_reads) { + read_predicates.push_back(*state_read->predicate()); + } + + Node* any_read_active; + if (state_reads.size() == 1) { + any_read_active = *state_reads[0]->predicate(); + } else { + XLS_ASSIGN_OR_RETURN( + any_read_active, + proc->MakeNodeWithName( + SourceInfo(), read_predicates, Op::kOr, + absl::StrCat("__", state_element->name(), "__any_read_active"))); + } + for (Next* next : next_values) { XLS_RET_CHECK(next->predicate().has_value()); XLS_ASSIGN_OR_RETURN( @@ -239,8 +351,7 @@ absl::StatusOr AddWriteWithoutReadAsserts( Node * no_write_without_read, proc->MakeNodeWithName( SourceInfo(), - absl::MakeConstSpan({*state_read->predicate(), next_not_triggered}), - Op::kOr, + absl::MakeConstSpan({any_read_active, next_not_triggered}), Op::kOr, absl::StrCat("__", state_element->name(), "__no_next_", next->id(), "_without_read"))); std::string label = absl::StrCat("__", state_element->name(), "__next_", @@ -291,7 +402,19 @@ absl::StatusOr AddDefaultNextValue(Proc* proc, StateElement* state_element, const SchedulingPassOptions& options) { absl::btree_set predicates; - StateRead* state_read = proc->GetStateReadByStateElement(state_element); + absl::Span state_reads = + proc->GetStateReadsByStateElement(state_element); + XLS_RET_CHECK(!state_reads.empty()); + StateRead* state_read = nullptr; + for (StateRead* read : state_reads) { + if (!read->predicate().has_value()) { + state_read = read; + break; + } + } + if (state_read == nullptr) { + state_read = state_reads.front(); + } for (Next* next : proc->next_values(state_element)) { if (next->predicate().has_value()) { predicates.insert(*next->predicate()); @@ -303,14 +426,19 @@ absl::StatusOr AddDefaultNextValue(Proc* proc, if (predicates.empty()) { // No explicit `next_value` node; leave the state element unchanged by - // default. - XLS_RETURN_IF_ERROR(proc->MakeNodeWithName( - state_read->loc(), /*state_read=*/state_read, - /*value=*/state_read, - /*predicate=*/state_read->predicate(), - /*label=*/std::nullopt, - absl::StrCat(state_element->name(), "_default")) - .status()); + // default. Emits a Next node for each individual state read. + for (StateRead* read : state_reads) { + std::string next_name = + state_reads.size() == 1 + ? absl::StrCat(state_element->name(), "_default") + : absl::StrCat(state_element->name(), "_default_", read->id()); + XLS_RETURN_IF_ERROR( + proc->MakeNodeWithName(read->loc(), /*state_read=*/read, + /*value=*/read, + /*predicate=*/read->predicate(), + /*label=*/std::nullopt, next_name) + .status()); + } return true; } @@ -383,26 +511,32 @@ absl::StatusOr AddDefaultNextValue(Proc* proc, } // Explicitly mark the state element as unchanged when no other `next_value` - // node is active. + // node is active. Emit a Next node for each individual state read. XLS_ASSIGN_OR_RETURN( Node * default_predicate, NaryNorIfNeeded(proc, std::vector(predicates.begin(), predicates.end()), /*name=*/"", state_read->loc())); - if (state_read->predicate().has_value()) { - XLS_ASSIGN_OR_RETURN( - default_predicate, - proc->MakeNode( - state_read->loc(), - absl::MakeConstSpan({*state_read->predicate(), default_predicate}), - Op::kAnd)); - } - XLS_RETURN_IF_ERROR(proc->MakeNodeWithName( - state_read->loc(), /*state_read=*/state_read, - /*value=*/state_read, - /*predicate=*/default_predicate, - /*label=*/std::nullopt, - absl::StrCat(state_element->name(), "_default")) - .status()); + for (StateRead* read : state_reads) { + Node* specific_predicate = default_predicate; + if (read->predicate().has_value()) { + XLS_ASSIGN_OR_RETURN( + specific_predicate, + proc->MakeNode( + read->loc(), + absl::MakeConstSpan({*read->predicate(), default_predicate}), + Op::kAnd)); + } + std::string next_name = + state_reads.size() == 1 + ? absl::StrCat(state_element->name(), "_default") + : absl::StrCat(state_element->name(), "_default_", read->id()); + XLS_RETURN_IF_ERROR( + proc->MakeNodeWithName(read->loc(), /*state_read=*/read, + /*value=*/read, + /*predicate=*/specific_predicate, + /*label=*/std::nullopt, next_name) + .status()); + } return true; } @@ -447,7 +581,6 @@ absl::StatusOr ProcStateLegalizationPass::RunOnFunctionBaseInternal( if (mutex_asserts_added) { changed = true; } - XLS_ASSIGN_OR_RETURN(bool write_without_read_asserts_added, AddWriteWithoutReadAsserts(proc, options)); if (write_without_read_asserts_added) { diff --git a/xls/scheduling/proc_state_legalization_pass_test.cc b/xls/scheduling/proc_state_legalization_pass_test.cc index 668a585827..9ef5a9ee5f 100644 --- a/xls/scheduling/proc_state_legalization_pass_test.cc +++ b/xls/scheduling/proc_state_legalization_pass_test.cc @@ -23,16 +23,20 @@ #include "absl/algorithm/container.h" #include "absl/status/status_matchers.h" #include "absl/status/statusor.h" +#include "absl/types/span.h" #include "xls/common/status/matchers.h" #include "xls/common/status/ret_check.h" #include "xls/common/status/status_macros.h" #include "xls/ir/bits.h" +#include "xls/ir/fileno.h" #include "xls/ir/function_builder.h" #include "xls/ir/ir_matcher.h" #include "xls/ir/ir_test_base.h" #include "xls/ir/nodes.h" +#include "xls/ir/op.h" #include "xls/ir/package.h" #include "xls/ir/proc_conversion.h" +#include "xls/ir/source_location.h" #include "xls/ir/value.h" #include "xls/passes/cse_pass.h" #include "xls/passes/dce_pass.h" @@ -549,6 +553,59 @@ TEST_P(ProcStateLegalizationPassTest, m::Literal(0)))))))); } +TEST_P(ProcStateLegalizationPassTest, + VerifyDefaultNextValueFallbackValuePicksUnconditionalRead) { + auto p = CreatePackage(); + ProcBuilder pb("p", p.get()); + XLS_ASSERT_OK_AND_ASSIGN( + StateElement * x_se, + pb.proc()->AppendUnreadStateElement("x", Value(UBits(0, 32)))); + XLS_ASSERT_OK_AND_ASSIGN( + StateElement * y_se, + pb.proc()->AppendUnreadStateElement("y", Value(UBits(0, 32)))); + + // Read 1: conditional, with distinct source location + SourceInfo loc1(SourceLocation(Fileno(1), Lineno(10), Colno(5))); + BValue p_read = pb.Literal(UBits(1, 1)); + pb.StateRead(x_se, p_read, /*label=*/std::nullopt, loc1); + + // Read 2: unconditional, with distinct source location + SourceInfo loc2(SourceLocation(Fileno(1), Lineno(20), Colno(15))); + pb.StateRead(x_se, /*predicate=*/std::nullopt, + /*label=*/std::nullopt, loc2); + + // Explicit user write gated by a dynamic condition to ensure the write + // suppression logic isn't constant-folded away by redundant node filters. + BValue p_write = pb.Eq(pb.StateRead(y_se), pb.Literal(UBits(5, 32))); + pb.Next(x_se, pb.Literal(UBits(5, 32)), p_write); + + XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, pb.Build()); + XLS_ASSERT_OK(p->SetTop(proc)); + + ASSERT_THAT(Run(proc), IsOkAndHolds(true)); + + // Find the base NOR/NOT node generated for the write suppression tree. + Node* base_nor = nullptr; + for (Node* node : proc->nodes()) { + if (node->OpIn({Op::kNor, Op::kNot}) && + node->operand(0) == p_write.node()) { + base_nor = node; + break; + } + } + ASSERT_NE(base_nor, nullptr); + EXPECT_EQ(base_nor->loc().ToString(), loc2.ToString()); + + // Verify that the independent read mutual exclusion sub-pass successfully + // generates safety checks guaranteeing at most one read fires per cycle. + std::vector asserts; + absl::c_copy_if(proc->nodes(), std::back_inserter(asserts), + [](Node* node) { return node->Is(); }); + EXPECT_THAT(asserts, Contains(m::Assert( + _, m::Eq(m::Concat(p_read.node(), m::Literal(1)), + m::BitSlice(m::OneHot(m::Concat())))))); +} + INSTANTIATE_TEST_SUITE_P(ProcStateLegalizationPassTestSuite, ProcStateLegalizationPassTest, testing::Values(false, true));