Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 106 additions & 4 deletions mlir/include/mlir/Dialect/DXSA/IR/DXSAOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,112 @@ include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinAttributeInterfaces.td"
include "mlir/IR/EnumAttr.td"

//===----------------------------------------------------------------------===//
// DXSA op base class
//===----------------------------------------------------------------------===//

// Base class for all operations in this dialect.
class DXSA_Op<string mnemonic, list<Trait> traits = []> :
Op<DXSADialect, mnemonic, traits>;

//===----------------------------------------------------------------------===//
// DXSA module — top-level container op for a DXBC tokenized program
//===----------------------------------------------------------------------===//

def DXSA_ProgramType_PixelShader : I32EnumAttrCase<"pixel_shader", 0>;
def DXSA_ProgramType_VertexShader : I32EnumAttrCase<"vertex_shader", 1>;
def DXSA_ProgramType_GeometryShader : I32EnumAttrCase<"geometry_shader", 2>;
def DXSA_ProgramType_HullShader : I32EnumAttrCase<"hull_shader", 3>;
def DXSA_ProgramType_DomainShader : I32EnumAttrCase<"domain_shader", 4>;
def DXSA_ProgramType_ComputeShader : I32EnumAttrCase<"compute_shader", 5>;
def DXSA_ProgramType_MeshShader : I32EnumAttrCase<"mesh_shader", 13>;
def DXSA_ProgramType_AmplificationShader : I32EnumAttrCase<"amplification_shader",14>;

def DXSA_ProgramType : I32EnumAttr<
"ProgramType", "DXBC tokenized program type", [
DXSA_ProgramType_PixelShader,
DXSA_ProgramType_VertexShader,
DXSA_ProgramType_GeometryShader,
DXSA_ProgramType_HullShader,
DXSA_ProgramType_DomainShader,
DXSA_ProgramType_ComputeShader,
DXSA_ProgramType_MeshShader,
DXSA_ProgramType_AmplificationShader
]> {
let cppNamespace = "::mlir::dxsa";
let genSpecializedAttr = 0;
}

def DXSA_ProgramTypeAttr :
EnumAttr<DXSADialect, DXSA_ProgramType, "program_type"> {
let assemblyFormat = "$value";
}

def DXSA_ShaderVersionAttr : AttrDef<DXSADialect, "ShaderVersion"> {
let mnemonic = "shader_version";
let summary = "DXBC shader version (major.minor)";
let description = [{
The `#dxsa.shader_version` attribute holds the major and minor
components of shader model version.

Example:

```mlir
#dxsa.shader_version<5, 0>
```
}];
let parameters = (ins "uint8_t":$major, "uint8_t":$minor);
let assemblyFormat = "`<` $major `,` $minor `>`";
}

def DXSA_ModuleOp : DXSA_Op<"module", [
IsolatedFromAbove, NoRegionArguments, NoTerminator, SingleBlock]> {
let summary = "the top-level container for a shader program";
let description = [{
The `dxsa.module` operation models the top-level container of a single
shader tokenized program (one SHEX section of the DXBC binary).

The two optional attributes are shader program type and version.
Both attributes are either both present (real binary with a SHEX
header) or both absent (header-less raw token streams).

Example:

```mlir
// Binary content with a SHEX header
dxsa.module pixel_shader 5 0 {
dxsa.dcl_global_flags <refactoringAllowed>
}

// Binary content without a SHEX header
dxsa.module {
dxsa.dcl_global_flags <refactoringAllowed>
}
```
}];

let arguments = (ins
OptionalAttr<DXSA_ProgramTypeAttr>:$program_type,
OptionalAttr<DXSA_ShaderVersionAttr>:$shader_version);
let regions = (region SizedRegion<1>:$body);

let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;

let skipDefaultBuilders = 1;
let builders = [
OpBuilder<(ins
CArg<"::mlir::dxsa::ProgramTypeAttr",
"::mlir::dxsa::ProgramTypeAttr()">:$programType,
CArg<"::mlir::dxsa::ShaderVersionAttr",
"::mlir::dxsa::ShaderVersionAttr()">:$shaderVersion)>
];

let extraClassDeclaration = [{
::mlir::Block *getBodyBlock() { return &getBody().front(); }
}];
}

//===----------------------------------------------------------------------===//
// DXSA enum definitions
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -313,10 +419,6 @@ def DXSA_UIntNonZero : AttrConstraint<
// DXSA op definitions
//===----------------------------------------------------------------------===//

// Base class for the operation in this dialect
class DXSA_Op<string mnemonic, list<Trait> traits = []> :
Op<DXSADialect, mnemonic, traits>;

def DXSA_Operand : DXSA_Op<"operand"> {
let summary = "defines an operand of an instruction";
let description = [{
Expand Down
12 changes: 7 additions & 5 deletions mlir/include/mlir/Target/DXSA/BinaryParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,20 @@
#ifndef MLIR_TARGET_DXSA_BINARYPARSER_H
#define MLIR_TARGET_DXSA_BINARYPARSER_H

#include "mlir/Dialect/DXSA/IR/DXSA.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OwningOpRef.h"
#include "llvm/Support/SourceMgr.h"

namespace mlir::dxsa {
/// Deserializes the given binary \p source and creates a MLIR ModuleOp in the
/// given \p context.
OwningOpRef<dxsa::ModuleOp> deserialize(llvm::SourceMgr &source,
MLIRContext *context);

/// Decode DXSA binary \p source and return an MLIR module.
OwningOpRef<ModuleOp> importDxsaBinaryToModule(llvm::SourceMgr &source,
MLIRContext *context);
/// Encode \p source to DXSA binary.
LogicalResult exportModuleToDxsaBinary(ModuleOp source, raw_ostream &output);
/// Serializes the given MLIR \p moduleOp and writes to \p output.
LogicalResult serialize(mlir::ModuleOp moduleOp, raw_ostream &output);
} // namespace mlir::dxsa

#endif // MLIR_TARGET_DXSA_BINARYPARSER_H
69 changes: 69 additions & 0 deletions mlir/lib/Dialect/DXSA/IR/DXSA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"

Expand Down Expand Up @@ -41,6 +42,74 @@ void DXSADialect::initialize() {
#define GET_OP_CLASSES
#include "mlir/Dialect/DXSA/IR/DXSAOps.cpp.inc"

//===----------------------------------------------------------------------===//
// ModuleOp
//===----------------------------------------------------------------------===//

void ModuleOp::build(OpBuilder &builder, OperationState &state,
ProgramTypeAttr programType,
ShaderVersionAttr shaderVersion) {
if (programType)
state.addAttribute("program_type", programType);
if (shaderVersion)
state.addAttribute("shader_version", shaderVersion);
OpBuilder::InsertionGuard guard(builder);
builder.createBlock(state.addRegion());
}

ParseResult ModuleOp::parse(OpAsmParser &parser, OperationState &result) {
// Parse optional shader information like `pixel_shader 5 0`.
StringRef typeKeyword;
auto typeLoc = parser.getCurrentLocation();
if (succeeded(parser.parseOptionalKeyword(&typeKeyword))) {
auto programType = symbolizeProgramType(typeKeyword);
if (!programType)
return parser.emitError(typeLoc)
<< "unknown program type: " << typeKeyword;
result.addAttribute("program_type", ProgramTypeAttr::get(
parser.getContext(), *programType));

uint8_t major = 0, minor = 0;
if (parser.parseInteger(major) || parser.parseInteger(minor))
Comment thread
tagolog marked this conversation as resolved.
return failure();
result.addAttribute(
"shader_version",
ShaderVersionAttr::get(parser.getContext(), major, minor));
}

Region *body = result.addRegion();
if (parser.parseOptionalAttrDictWithKeyword(result.attributes) ||
parser.parseRegion(*body, /*arguments=*/{}))
return failure();

if (body->empty())
body->push_back(new Block());

return success();
}

void ModuleOp::print(OpAsmPrinter &printer) {
if (auto programType = getProgramType()) {
printer << ' ' << stringifyProgramType(*programType);
auto version = getShaderVersionAttr();
printer << ' ' << static_cast<unsigned>(version.getMajor()) << ' '
<< static_cast<unsigned>(version.getMinor());
}
printer.printOptionalAttrDictWithKeyword((*this)->getAttrs(),
{"program_type", "shader_version"});
printer << ' ';
printer.printRegion(getBody());
}

LogicalResult ModuleOp::verify() {
bool hasType = static_cast<bool>(getProgramTypeAttr());
bool hasVersion = static_cast<bool>(getShaderVersionAttr());
if (hasType != hasVersion)
return emitOpError(
"program_type and shader_version must both be present or both absent");
return success();
}

//===----------------------------------------------------------------------===//
// Op verifiers
//===----------------------------------------------------------------------===//
Expand Down
80 changes: 64 additions & 16 deletions mlir/lib/Target/DXSA/BinaryParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -420,15 +420,23 @@ static dxsa::ComponentMask decodeComponentMask(uint32_t rawComponentMask) {

class DXBuilder {
public:
DXBuilder(MLIRContext *context, StringAttr name)
: context(context),
module(ModuleOp::create(builder, FileLineColLoc::get(name, 0, 0))),
builder(module.getRegion()) {}
explicit DXBuilder(MLIRContext *context)
: context(context), builder(context) {}

using Index = mlir::Value;
using Operand = mlir::Value;
using Instruction = mlir::Operation *;
using Module = mlir::ModuleOp;
using Module = mlir::dxsa::ModuleOp;

Module createModule(mlir::dxsa::ProgramTypeAttr programType,
mlir::dxsa::ShaderVersionAttr shaderVersion,
Location loc) {
OperationState state(loc, Module::getOperationName());
Module::build(builder, state, programType, shaderVersion);
auto module = cast<Module>(Operation::create(state));
builder.setInsertionPointToStart(&module.getBody().front());
return module;
}

Index buildIndexImm32(int32_t imm, FileLineColLoc loc) {
Operation *op =
Expand Down Expand Up @@ -523,10 +531,6 @@ class DXBuilder {
builder.getStringAttr(name));
}

Module buildModule(ArrayRef<Instruction> instructions, FileLineColLoc loc) {
return module;
}

Instruction buildDclGlobalFlags(dxsa::GlobalFlags flags, Location loc) {
auto flagsAttr = dxsa::GlobalFlagsAttr::get(builder.getContext(), flags);
return dxsa::DclGlobalFlags::create(builder, loc, flagsAttr);
Expand Down Expand Up @@ -706,7 +710,6 @@ class DXBuilder {

private:
MLIRContext *context;
ModuleOp module;
OpBuilder builder;
};

Expand Down Expand Up @@ -1530,15 +1533,60 @@ class Parser {

FailureOr<Module> parseModule() {
FileLineColLoc loc = getLocation(0);
std::vector<Instruction> instructions;
auto header = parseProgramHeader();
FAILURE_IF_FAILED(header);
mlir::dxsa::ProgramTypeAttr programType;
mlir::dxsa::ShaderVersionAttr shaderVersion;
if (*header) {
programType =
mlir::dxsa::ProgramTypeAttr::get(name.getContext(), (*header)->type);
shaderVersion = mlir::dxsa::ShaderVersionAttr::get(
name.getContext(), (*header)->major, (*header)->minor);
}
auto module = builder.createModule(programType, shaderVersion, loc);
while (currentTokenOffset < buffer.size()) {
FailureOr<Instruction> inst = parseInstruction();
if (failed(inst)) {
return failure();
}
instructions.push_back(*inst);
}
return builder.buildModule(instructions, loc);
return module;
}

struct ProgramHeader {
mlir::dxsa::ProgramType type;
uint8_t major;
uint8_t minor;
};

/// If the buffer begins with a tokenized-program header (VersionToken +
/// LengthToken), decode and consume both tokens and return the program type
/// and shader model. Otherwise return without touching the parser current
/// position.
FailureOr<std::optional<ProgramHeader>> parseProgramHeader() {
constexpr size_t headerSize = 2 * sizeof(uint32_t);
if (currentTokenOffset + headerSize > buffer.size())
return std::optional<ProgramHeader>{};

auto versionToken = support::endian::read<uint32_t>(
buffer.begin() + currentTokenOffset, endianness::little);
if (DECODE_D3D10_SB_TOKENIZED_INSTRUCTION_LENGTH(versionToken) != 0)
return std::optional<ProgramHeader>{};

auto rawType = static_cast<uint32_t>(
DECODE_D3D10_SB_TOKENIZED_PROGRAM_TYPE(versionToken));
auto programType = dxsa::symbolizeProgramType(rawType);
if (!programType)
return std::optional<ProgramHeader>{};

auto major = static_cast<uint8_t>(
DECODE_D3D10_SB_TOKENIZED_PROGRAM_MAJOR_VERSION(versionToken));
auto minor = static_cast<uint8_t>(
DECODE_D3D10_SB_TOKENIZED_PROGRAM_MINOR_VERSION(versionToken));

FAILURE_IF_FAILED(parseToken()); // VersionToken
FAILURE_IF_FAILED(parseToken()); // LengthToken
return std::optional<ProgramHeader>{{*programType, major, minor}};
}

LogicalResult verifyInstructionLength(size_t beginOffset, uint32_t length) {
Expand All @@ -1558,8 +1606,8 @@ class Parser {
};

namespace mlir::dxsa {
OwningOpRef<ModuleOp> importDxsaBinaryToModule(llvm::SourceMgr &source,
MLIRContext *context) {
OwningOpRef<ModuleOp> deserialize(llvm::SourceMgr &source,
MLIRContext *context) {

if (source.getNumBuffers() != 1) {
emitError(UnknownLoc::get(context), "one source file should be provided");
Expand All @@ -1575,7 +1623,7 @@ OwningOpRef<ModuleOp> importDxsaBinaryToModule(llvm::SourceMgr &source,
context->allowUnregisteredDialects();
context->loadAllAvailableDialects();

DXBuilder builder(context, name);
DXBuilder builder(context);
Parser parser(builder, name, buffer);
FailureOr<ModuleOp> mod = parser.parseModule();
if (failed(mod))
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Target/DXSA/BinaryWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ using namespace mlir;
using namespace llvm;

namespace mlir::dxsa {
LogicalResult exportModuleToDxsaBinary(ModuleOp source, raw_ostream &output) {
LogicalResult serialize(mlir::ModuleOp source, raw_ostream &output) {
Region &region = source.getRegion();
assert(region.hasOneBlock() && "invalid module");
return failure();
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Target/DXSA/TranslateRegistration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ void registerFromDxsaBinTranslation() {
"import-dxsa-bin", "Translate DXSA binary to MLIR",
[](llvm::SourceMgr &sourceMgr,
MLIRContext *context) -> OwningOpRef<Operation *> {
return dxsa::importDxsaBinaryToModule(sourceMgr, context);
return dxsa::deserialize(sourceMgr, context);
},
[](DialectRegistry &registry) { registry.insert<dxsa::DXSADialect>(); }};
}
Expand All @@ -28,7 +28,7 @@ void registerToDxsaBinTranslation() {
TranslateFromMLIRRegistration registration{
"export-dxsa-bin", "Translate MLIR to DXSA binary",
[](ModuleOp source, raw_ostream &output) {
return dxsa::exportModuleToDxsaBinary(source, output);
return dxsa::serialize(source, output);
},
[](DialectRegistry &registry) { registry.insert<dxsa::DXSADialect>(); }};
}
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Target/DXSA/dcl_constant_buffer.mlir
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// RUN: mlir-translate --import-dxsa-bin %S/inputs/dcl_constant_buffer.bin | FileCheck %s

// CHECK: module {
// CHECK: dxsa.module {
// CHECK-NEXT: dxsa.dcl_constant_buffer <id = 0, size = 1>, <immediateIndexed>
// CHECK-NEXT: dxsa.dcl_constant_buffer <id = 0, size = 4, lbound = 0, ubound = 3, space = 1>, <dynamicIndexed>
// CHECK-NEXT: }
Loading