Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions xls/scheduling/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,7 @@ cc_library(
"//xls/ir:value",
"//xls/passes:pass_base",
"//xls/solvers:z3_ir_translator",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:btree",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/log",
Expand Down Expand Up @@ -488,7 +489,9 @@ cc_test(
"//xls/ir:function_builder",
"//xls/ir:ir_matcher",
"//xls/ir:ir_test_base",
"//xls/ir:op",
"//xls/ir:proc_conversion",
"//xls/ir:source_location",
"//xls/ir:value",
"//xls/passes:cse_pass",
"//xls/passes:dce_pass",
Expand All @@ -497,6 +500,7 @@ cc_test(
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/status:status_matchers",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/types:span",
"@googletest//:gtest",
],
)
Expand Down
201 changes: 167 additions & 34 deletions xls/scheduling/proc_state_legalization_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <variant>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/btree_set.h"
#include "absl/container/flat_hash_set.h"
#include "absl/log/check.h"
Expand Down Expand Up @@ -53,9 +54,19 @@ namespace {
absl::StatusOr<bool> LegalizeStateReadPredicate(
Proc* proc, StateElement* state_element,
const SchedulingPassOptions& options) {
StateRead* state_read = proc->GetStateReadByStateElement(state_element);
absl::Span<StateRead* const> state_reads =
proc->GetStateReadsByStateElement(state_element);
const absl::btree_set<Next*, Node::NodeIdLessThan>& next_values =
proc->next_values(state_element);

// For now only legalize predicate if there is exactly one state_read to
// ensure a simple read and write pass. Later on we can expand this to
// handle multiple state_reads.
if (state_reads.size() != 1) {
return false;
}

StateRead* state_read = state_reads.front();
if (!state_read->predicate().has_value() || next_values.empty()) {
// Already unconditional, or no explicit `next_value`s; nothing to do.
return false;
Expand All @@ -66,7 +77,7 @@ absl::StatusOr<bool> LegalizeStateReadPredicate(
predicates.reserve(1 + next_values.size());
predicates_set.reserve(next_values.size());
for (Next* next : next_values) {
if (next->state_read() == next->value()) {
if (state_read == next->value()) {
// This is a no-op next_value; we will narrow it to the case where the
// state read is active instead.
continue;
Expand Down Expand Up @@ -133,7 +144,7 @@ absl::StatusOr<bool> LegalizeStateReadPredicates(
return changed;
}

absl::StatusOr<bool> AddMutualExclusionAssert(
absl::StatusOr<bool> AddWriteMutualExclusionAssert(
Proc* proc, StateElement* state_element,
const SchedulingPassOptions& options) {
const absl::btree_set<Next*, Node::NodeIdLessThan>& next_values =
Expand Down Expand Up @@ -195,6 +206,87 @@ absl::StatusOr<bool> AddMutualExclusionAssert(
return true;
}

absl::StatusOr<bool> AddReadMutualExclusionAssert(
Proc* proc, StateElement* state_element,
const SchedulingPassOptions& options) {
absl::Span<StateRead* const> state_reads =
proc->GetStateReadsByStateElement(state_element);
if (state_reads.size() < 2) {
return false;
}
std::string label =
absl::StrCat("__", state_element->name(), "__at_most_one_read_assert");
if (proc->HasNode(label)) {
return absl::InternalError(absl::StrFormat(
"Read mutual exclusion assert already exists for state "
"element '%s'; was this pass run twice? assert label: %s",
state_element->name(), label));
}
std::vector<Node*> predicate_list;
Node* true_lit = nullptr;
for (StateRead* state_read : state_reads) {
if (state_read->predicate().has_value()) {
predicate_list.push_back(*state_read->predicate());
} else {
if (true_lit == nullptr) {
XLS_ASSIGN_OR_RETURN(
true_lit, proc->MakeNode<Literal>(SourceInfo(), Value::Bool(true)));
}
predicate_list.push_back(true_lit);
}
}
XLS_ASSIGN_OR_RETURN(
Node * predicates,
proc->MakeNodeWithName<Concat>(
SourceInfo(), predicate_list,
absl::StrCat("__", state_element->name(), "__read_predicates")));
XLS_ASSIGN_OR_RETURN(
Node * one_hot_predicates,
proc->MakeNode<OneHot>(SourceInfo(), predicates, LsbOrMsb::kLsb));
XLS_ASSIGN_OR_RETURN(Node * at_most_one_predicate,
proc->MakeNode<BitSlice>(
SourceInfo(), one_hot_predicates, /*start=*/0,
/*width=*/one_hot_predicates->BitCountOrDie() - 1));
XLS_ASSIGN_OR_RETURN(
Node * at_most_one_read,
proc->MakeNodeWithName<CompareOp>(
SourceInfo(), predicates, at_most_one_predicate, Op::kEq,
absl::StrCat("__", state_element->name(), "__at_most_one_read")));
XLS_ASSIGN_OR_RETURN(Node * tkn,
proc->MakeNode<Literal>(SourceInfo(), Value::Token()));
XLS_RETURN_IF_ERROR(
proc->MakeNodeWithName<Assert>(
SourceInfo(), tkn,
/*condition=*/at_most_one_read,
/*message=*/
absl::StrCat("More than one StateRead active for state element: ",
state_element->name()),
/*label=*/label,
/*original_label=*/std::nullopt,
/*name=*/label)
.status());
return true;
}

absl::StatusOr<bool> AddMutualExclusionAssert(
Proc* proc, StateElement* state_element,
const SchedulingPassOptions& options) {
bool changed = false;
XLS_ASSIGN_OR_RETURN(
bool write_assert_added,
AddWriteMutualExclusionAssert(proc, state_element, options));
if (write_assert_added) {
changed = true;
}
XLS_ASSIGN_OR_RETURN(
bool read_assert_added,
AddReadMutualExclusionAssert(proc, state_element, options));
if (read_assert_added) {
changed = true;
}
return changed;
}

absl::StatusOr<bool> AddMutualExclusionAsserts(
Proc* proc, const SchedulingPassOptions& options) {
bool changed = false;
Expand All @@ -215,8 +307,12 @@ absl::StatusOr<bool> AddMutualExclusionAsserts(
absl::StatusOr<bool> AddWriteWithoutReadAsserts(
Proc* proc, StateElement* state_element,
const SchedulingPassOptions& options) {
StateRead* state_read = proc->GetStateReadByStateElement(state_element);
if (!state_read->predicate().has_value()) {
absl::Span<StateRead* const> state_reads =
proc->GetStateReadsByStateElement(state_element);

if (absl::c_any_of(state_reads, [](StateRead* state_read) {
return !state_read->predicate().has_value();
})) {
return false;
}

Expand All @@ -226,7 +322,23 @@ absl::StatusOr<bool> AddWriteWithoutReadAsserts(
return false;
}

std::vector<Node*> predicate_list;
std::vector<Node*> read_predicates;
read_predicates.reserve(state_reads.size());
for (StateRead* state_read : state_reads) {
read_predicates.push_back(*state_read->predicate());
}

Node* any_read_active;
if (state_reads.size() == 1) {
any_read_active = *state_reads[0]->predicate();
} else {
XLS_ASSIGN_OR_RETURN(
any_read_active,
proc->MakeNodeWithName<NaryOp>(
SourceInfo(), read_predicates, Op::kOr,
absl::StrCat("__", state_element->name(), "__any_read_active")));
}

for (Next* next : next_values) {
XLS_RET_CHECK(next->predicate().has_value());
XLS_ASSIGN_OR_RETURN(
Expand All @@ -239,8 +351,7 @@ absl::StatusOr<bool> AddWriteWithoutReadAsserts(
Node * no_write_without_read,
proc->MakeNodeWithName<NaryOp>(
SourceInfo(),
absl::MakeConstSpan({*state_read->predicate(), next_not_triggered}),
Op::kOr,
absl::MakeConstSpan({any_read_active, next_not_triggered}), Op::kOr,
absl::StrCat("__", state_element->name(), "__no_next_", next->id(),
"_without_read")));
std::string label = absl::StrCat("__", state_element->name(), "__next_",
Expand Down Expand Up @@ -291,7 +402,19 @@ absl::StatusOr<bool> AddDefaultNextValue(Proc* proc,
StateElement* state_element,
const SchedulingPassOptions& options) {
absl::btree_set<Node*, Node::NodeIdLessThan> predicates;
StateRead* state_read = proc->GetStateReadByStateElement(state_element);
absl::Span<StateRead* const> state_reads =
proc->GetStateReadsByStateElement(state_element);
XLS_RET_CHECK(!state_reads.empty());
StateRead* state_read = nullptr;
for (StateRead* read : state_reads) {
if (!read->predicate().has_value()) {
state_read = read;
break;
}
}
if (state_read == nullptr) {
state_read = state_reads.front();
}
for (Next* next : proc->next_values(state_element)) {
if (next->predicate().has_value()) {
predicates.insert(*next->predicate());
Expand All @@ -303,14 +426,19 @@ absl::StatusOr<bool> AddDefaultNextValue(Proc* proc,

if (predicates.empty()) {
// No explicit `next_value` node; leave the state element unchanged by
// default.
XLS_RETURN_IF_ERROR(proc->MakeNodeWithName<Next>(
state_read->loc(), /*state_read=*/state_read,
/*value=*/state_read,
/*predicate=*/state_read->predicate(),
/*label=*/std::nullopt,
absl::StrCat(state_element->name(), "_default"))
.status());
// default. Emits a Next node for each individual state read.
for (StateRead* read : state_reads) {
std::string next_name =
state_reads.size() == 1
? absl::StrCat(state_element->name(), "_default")
: absl::StrCat(state_element->name(), "_default_", read->id());
XLS_RETURN_IF_ERROR(
proc->MakeNodeWithName<Next>(read->loc(), /*state_read=*/read,
/*value=*/read,
/*predicate=*/read->predicate(),
/*label=*/std::nullopt, next_name)
.status());
}
return true;
}

Expand Down Expand Up @@ -383,26 +511,32 @@ absl::StatusOr<bool> AddDefaultNextValue(Proc* proc,
}

// Explicitly mark the state element as unchanged when no other `next_value`
// node is active.
// node is active. Emit a Next node for each individual state read.
XLS_ASSIGN_OR_RETURN(
Node * default_predicate,
NaryNorIfNeeded(proc, std::vector(predicates.begin(), predicates.end()),
/*name=*/"", state_read->loc()));
if (state_read->predicate().has_value()) {
XLS_ASSIGN_OR_RETURN(
default_predicate,
proc->MakeNode<NaryOp>(
state_read->loc(),
absl::MakeConstSpan({*state_read->predicate(), default_predicate}),
Op::kAnd));
}
XLS_RETURN_IF_ERROR(proc->MakeNodeWithName<Next>(
state_read->loc(), /*state_read=*/state_read,
/*value=*/state_read,
/*predicate=*/default_predicate,
/*label=*/std::nullopt,
absl::StrCat(state_element->name(), "_default"))
.status());
for (StateRead* read : state_reads) {
Node* specific_predicate = default_predicate;
if (read->predicate().has_value()) {
XLS_ASSIGN_OR_RETURN(
specific_predicate,
proc->MakeNode<NaryOp>(
read->loc(),
absl::MakeConstSpan({*read->predicate(), default_predicate}),
Op::kAnd));
}
std::string next_name =
state_reads.size() == 1
? absl::StrCat(state_element->name(), "_default")
: absl::StrCat(state_element->name(), "_default_", read->id());
XLS_RETURN_IF_ERROR(
proc->MakeNodeWithName<Next>(read->loc(), /*state_read=*/read,
/*value=*/read,
/*predicate=*/specific_predicate,
/*label=*/std::nullopt, next_name)
.status());
}
return true;
}

Expand Down Expand Up @@ -447,7 +581,6 @@ absl::StatusOr<bool> ProcStateLegalizationPass::RunOnFunctionBaseInternal(
if (mutex_asserts_added) {
changed = true;
}

XLS_ASSIGN_OR_RETURN(bool write_without_read_asserts_added,
AddWriteWithoutReadAsserts(proc, options));
if (write_without_read_asserts_added) {
Expand Down
57 changes: 57 additions & 0 deletions xls/scheduling/proc_state_legalization_pass_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,20 @@
#include "absl/algorithm/container.h"
#include "absl/status/status_matchers.h"
#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "xls/common/status/matchers.h"
#include "xls/common/status/ret_check.h"
#include "xls/common/status/status_macros.h"
#include "xls/ir/bits.h"
#include "xls/ir/fileno.h"
#include "xls/ir/function_builder.h"
#include "xls/ir/ir_matcher.h"
#include "xls/ir/ir_test_base.h"
#include "xls/ir/nodes.h"
#include "xls/ir/op.h"
#include "xls/ir/package.h"
#include "xls/ir/proc_conversion.h"
#include "xls/ir/source_location.h"
#include "xls/ir/value.h"
#include "xls/passes/cse_pass.h"
#include "xls/passes/dce_pass.h"
Expand Down Expand Up @@ -549,6 +553,59 @@ TEST_P(ProcStateLegalizationPassTest,
m::Literal(0))))))));
}

TEST_P(ProcStateLegalizationPassTest,
VerifyDefaultNextValueFallbackValuePicksUnconditionalRead) {
auto p = CreatePackage();
ProcBuilder pb("p", p.get());
XLS_ASSERT_OK_AND_ASSIGN(
StateElement * x_se,
pb.proc()->AppendUnreadStateElement("x", Value(UBits(0, 32))));
XLS_ASSERT_OK_AND_ASSIGN(
StateElement * y_se,
pb.proc()->AppendUnreadStateElement("y", Value(UBits(0, 32))));

// Read 1: conditional, with distinct source location
SourceInfo loc1(SourceLocation(Fileno(1), Lineno(10), Colno(5)));
BValue p_read = pb.Literal(UBits(1, 1));
pb.StateRead(x_se, p_read, /*label=*/std::nullopt, loc1);

// Read 2: unconditional, with distinct source location
SourceInfo loc2(SourceLocation(Fileno(1), Lineno(20), Colno(15)));
pb.StateRead(x_se, /*predicate=*/std::nullopt,
/*label=*/std::nullopt, loc2);

// Explicit user write gated by a dynamic condition to ensure the write
// suppression logic isn't constant-folded away by redundant node filters.
BValue p_write = pb.Eq(pb.StateRead(y_se), pb.Literal(UBits(5, 32)));
pb.Next(x_se, pb.Literal(UBits(5, 32)), p_write);

XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, pb.Build());
XLS_ASSERT_OK(p->SetTop(proc));

ASSERT_THAT(Run(proc), IsOkAndHolds(true));

// Find the base NOR/NOT node generated for the write suppression tree.
Node* base_nor = nullptr;
for (Node* node : proc->nodes()) {
if (node->OpIn({Op::kNor, Op::kNot}) &&
node->operand(0) == p_write.node()) {
base_nor = node;
break;
}
}
ASSERT_NE(base_nor, nullptr);
EXPECT_EQ(base_nor->loc().ToString(), loc2.ToString());

// Verify that the independent read mutual exclusion sub-pass successfully
// generates safety checks guaranteeing at most one read fires per cycle.
std::vector<Node*> asserts;
absl::c_copy_if(proc->nodes(), std::back_inserter(asserts),
[](Node* node) { return node->Is<Assert>(); });
EXPECT_THAT(asserts, Contains(m::Assert(
_, m::Eq(m::Concat(p_read.node(), m::Literal(1)),
m::BitSlice(m::OneHot(m::Concat()))))));
}

INSTANTIATE_TEST_SUITE_P(ProcStateLegalizationPassTestSuite,
ProcStateLegalizationPassTest,
testing::Values(false, true));
Expand Down
Loading