Skip to main content

Cast.h File

Cast Node — CVU kernel that converts between FP32 and BF16 element-wise. More...

Included Headers

#include "builder/Node.h"
#include "builder/NodeContractConfigurable.h"
#include "builder/NodeContractProvider.h"
#include "builder/OutputSpec.h"
#include <memory>
#include <string>
#include <vector>

Namespaces Index

namespace simaai
namespace simaai::neat
namespace simaai::neat::nodes

Classes Index

struct CastOptions

Construction options for a Cast Node. More...

class Cast

CVU kernel Node that casts a tensor between FP32 and BF16 (no scale/zero-point). More...

Description

Cast Node — CVU kernel that converts between FP32 and BF16 element-wise.

Cast is a pure dtype conversion with no scale or zero-point, which distinguishes it from Quant/Dequant. It is inserted at the MLA boundary on the BF16 path in two places. Before the MLA, it converts FP32 to BF16 when the model expects BF16 input but the upstream stage emits FP32. After the MLA, it sits on the BF16 output path and converts BF16 to FP32.

See Also

The "dtype contract" concept page (/concepts/dtype_contract).

File Listing

The file content with the documentation metadata removed is:

1
12#pragma once
13
14#include "builder/Node.h"
15#include "builder/NodeContractConfigurable.h"
16#include "builder/NodeContractProvider.h"
17#include "builder/OutputSpec.h"
18#ifdef SIMA_NEAT_INTERNAL
19#include "model/internal/ModelRouteRetarget.h"
20#endif
21
22#include <memory>
23#include <string>
24#include <vector>
25
26namespace simaai::neat {
// Forward declaration only: the type is used here solely through
// std::shared_ptr<const CompiledProcessCvuContract> (see CastOptions).
struct CompiledProcessCvuContract;

/// Direction of the element-wise dtype conversion between BF16 and FP32.
/// The underlying type is spelled out; `int` is already the default for an
/// `enum class`, so the layout and values are unchanged.
enum class CastDirection : int {
  Bf16ToFp32 = 0,  ///< BF16 source -> FP32 result.
  Fp32ToBf16 = 1,  ///< FP32 source -> BF16 result.
};
34
/// Construction options for a Cast Node. Both the Cast constructor and the
/// simaai::neat::nodes::Cast factory take a CastOptions by value (default `{}`).
40struct CastOptions {
42 std::string element_name;  ///< Element name for the node; empty by default. NOTE(review): presumably used by Cast::element_names() — confirm.
43 bool silent = true;  ///< Defaults to true. NOTE(review): presumably suppresses element logging — confirm.
44 std::shared_ptr<const CompiledProcessCvuContract>  // NOTE(review): this listing skips source line 45, so this member's declarator (name) is not shown; check Cast.h for the real member name.
46 int num_buffers = 0;  ///< Defaults to 0. NOTE(review): meaning of 0 (e.g. "use backend default") is not visible here — confirm.
47#ifdef SIMA_NEAT_INTERNAL
48 std::shared_ptr<const simaai::neat::internal::ModelLineageBinding> model_lineage;  ///< Present only when SIMA_NEAT_INTERNAL is defined.
49#endif
50};
51
64class Cast final : public Node,
65 public OutputSpecProvider,
66 public NodeContractProvider,
67 public NodeContractConfigurable {
68public:
70 explicit Cast(CastOptions opt = {});
71
73 std::string kind() const override {
74 return "Cast";
75 }
77 NodeCapsBehavior caps_behavior() const override {
78 return NodeCapsBehavior::Static;
79 }
81 std::string backend_fragment(int node_index) const override;
83 std::vector<std::string> element_names(int node_index) const override;
85 OutputSpec output_spec(const OutputSpec& input) const override;
89 bool compile_node_contract(const ContractCompileInput& input, CompiledNodeContract* out,
90 std::string* err) const override;
92 void apply_compiled_contract(const CompiledNodeContract& contract, std::string* err) override;
93
95 const CastOptions& options() const {
96 return opt_;
97 }
98
99private:
100 CastOptions opt_;
101};
102
103} // namespace simaai::neat
104
105namespace simaai::neat::nodes {
107std::shared_ptr<simaai::neat::Node> Cast(CastOptions opt = {});
108} // namespace simaai::neat::nodes

Generated via doxygen2docusaurus 2.0.0 by Doxygen 1.9.8.