xref: /src/contrib/llvm-project/llvm/lib/Analysis/TensorSpec.cpp (revision 06c3fb2749bda94cb5201f81ffdb8fa6c3161b2e)
1145449b1SDimitry Andric //===- TensorSpec.cpp - tensor type abstraction ---------------------------===//
2145449b1SDimitry Andric //
3145449b1SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4145449b1SDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
5145449b1SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6145449b1SDimitry Andric //
7145449b1SDimitry Andric //===----------------------------------------------------------------------===//
8145449b1SDimitry Andric //
9145449b1SDimitry Andric // Implementation file for the abstraction of a tensor type, and JSON loading
10145449b1SDimitry Andric // utils.
11145449b1SDimitry Andric //
12145449b1SDimitry Andric //===----------------------------------------------------------------------===//
137fa27ce4SDimitry Andric #include "llvm/ADT/STLExtras.h"
14145449b1SDimitry Andric #include "llvm/Config/config.h"
15145449b1SDimitry Andric 
167fa27ce4SDimitry Andric #include "llvm/ADT/StringExtras.h"
17145449b1SDimitry Andric #include "llvm/ADT/Twine.h"
18145449b1SDimitry Andric #include "llvm/Analysis/TensorSpec.h"
19145449b1SDimitry Andric #include "llvm/Support/CommandLine.h"
20145449b1SDimitry Andric #include "llvm/Support/Debug.h"
21145449b1SDimitry Andric #include "llvm/Support/JSON.h"
22145449b1SDimitry Andric #include "llvm/Support/ManagedStatic.h"
23145449b1SDimitry Andric #include "llvm/Support/raw_ostream.h"
24e3b55780SDimitry Andric #include <array>
25145449b1SDimitry Andric #include <cassert>
26145449b1SDimitry Andric #include <numeric>
27145449b1SDimitry Andric 
28145449b1SDimitry Andric using namespace llvm;
29145449b1SDimitry Andric 
30145449b1SDimitry Andric namespace llvm {
31145449b1SDimitry Andric 
32145449b1SDimitry Andric #define TFUTILS_GETDATATYPE_IMPL(T, E)                                         \
33145449b1SDimitry Andric   template <> TensorType TensorSpec::getDataType<T>() { return TensorType::E; }
34145449b1SDimitry Andric 
SUPPORTED_TENSOR_TYPES(TFUTILS_GETDATATYPE_IMPL)35145449b1SDimitry Andric SUPPORTED_TENSOR_TYPES(TFUTILS_GETDATATYPE_IMPL)
36145449b1SDimitry Andric 
37145449b1SDimitry Andric #undef TFUTILS_GETDATATYPE_IMPL
38145449b1SDimitry Andric 
39e3b55780SDimitry Andric static std::array<std::string, static_cast<size_t>(TensorType::Total)>
40e3b55780SDimitry Andric     TensorTypeNames{"INVALID",
41e3b55780SDimitry Andric #define TFUTILS_GETNAME_IMPL(T, _) #T,
42e3b55780SDimitry Andric                     SUPPORTED_TENSOR_TYPES(TFUTILS_GETNAME_IMPL)
43e3b55780SDimitry Andric #undef TFUTILS_GETNAME_IMPL
44e3b55780SDimitry Andric     };
45e3b55780SDimitry Andric 
toString(TensorType TT)46e3b55780SDimitry Andric StringRef toString(TensorType TT) {
47e3b55780SDimitry Andric   return TensorTypeNames[static_cast<size_t>(TT)];
48e3b55780SDimitry Andric }
49e3b55780SDimitry Andric 
toJSON(json::OStream & OS) const50e3b55780SDimitry Andric void TensorSpec::toJSON(json::OStream &OS) const {
51e3b55780SDimitry Andric   OS.object([&]() {
52e3b55780SDimitry Andric     OS.attribute("name", name());
53e3b55780SDimitry Andric     OS.attribute("type", toString(type()));
54e3b55780SDimitry Andric     OS.attribute("port", port());
55e3b55780SDimitry Andric     OS.attributeArray("shape", [&]() {
56e3b55780SDimitry Andric       for (size_t D : shape())
57e3b55780SDimitry Andric         OS.value(static_cast<int64_t>(D));
58e3b55780SDimitry Andric     });
59e3b55780SDimitry Andric   });
60e3b55780SDimitry Andric }
61e3b55780SDimitry Andric 
TensorSpec(const std::string & Name,int Port,TensorType Type,size_t ElementSize,const std::vector<int64_t> & Shape)62145449b1SDimitry Andric TensorSpec::TensorSpec(const std::string &Name, int Port, TensorType Type,
63145449b1SDimitry Andric                        size_t ElementSize, const std::vector<int64_t> &Shape)
64145449b1SDimitry Andric     : Name(Name), Port(Port), Type(Type), Shape(Shape),
65145449b1SDimitry Andric       ElementCount(std::accumulate(Shape.begin(), Shape.end(), 1,
66145449b1SDimitry Andric                                    std::multiplies<int64_t>())),
67145449b1SDimitry Andric       ElementSize(ElementSize) {}
68145449b1SDimitry Andric 
getTensorSpecFromJSON(LLVMContext & Ctx,const json::Value & Value)69e3b55780SDimitry Andric std::optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx,
70145449b1SDimitry Andric                                                 const json::Value &Value) {
71e3b55780SDimitry Andric   auto EmitError =
72e3b55780SDimitry Andric       [&](const llvm::Twine &Message) -> std::optional<TensorSpec> {
73145449b1SDimitry Andric     std::string S;
74145449b1SDimitry Andric     llvm::raw_string_ostream OS(S);
75145449b1SDimitry Andric     OS << Value;
76145449b1SDimitry Andric     Ctx.emitError("Unable to parse JSON Value as spec (" + Message + "): " + S);
77e3b55780SDimitry Andric     return std::nullopt;
78145449b1SDimitry Andric   };
79145449b1SDimitry Andric   // FIXME: accept a Path as a parameter, and use it for error reporting.
80145449b1SDimitry Andric   json::Path::Root Root("tensor_spec");
81145449b1SDimitry Andric   json::ObjectMapper Mapper(Value, Root);
82145449b1SDimitry Andric   if (!Mapper)
83145449b1SDimitry Andric     return EmitError("Value is not a dict");
84145449b1SDimitry Andric 
85145449b1SDimitry Andric   std::string TensorName;
86145449b1SDimitry Andric   int TensorPort = -1;
87145449b1SDimitry Andric   std::string TensorType;
88145449b1SDimitry Andric   std::vector<int64_t> TensorShape;
89145449b1SDimitry Andric 
90145449b1SDimitry Andric   if (!Mapper.map<std::string>("name", TensorName))
91145449b1SDimitry Andric     return EmitError("'name' property not present or not a string");
92145449b1SDimitry Andric   if (!Mapper.map<std::string>("type", TensorType))
93145449b1SDimitry Andric     return EmitError("'type' property not present or not a string");
94145449b1SDimitry Andric   if (!Mapper.map<int>("port", TensorPort))
95145449b1SDimitry Andric     return EmitError("'port' property not present or not an int");
96145449b1SDimitry Andric   if (!Mapper.map<std::vector<int64_t>>("shape", TensorShape))
97145449b1SDimitry Andric     return EmitError("'shape' property not present or not an int array");
98145449b1SDimitry Andric 
99145449b1SDimitry Andric #define PARSE_TYPE(T, E)                                                       \
100145449b1SDimitry Andric   if (TensorType == #T)                                                        \
101145449b1SDimitry Andric     return TensorSpec::createSpec<T>(TensorName, TensorShape, TensorPort);
102145449b1SDimitry Andric   SUPPORTED_TENSOR_TYPES(PARSE_TYPE)
103145449b1SDimitry Andric #undef PARSE_TYPE
104e3b55780SDimitry Andric   return std::nullopt;
105145449b1SDimitry Andric }
106145449b1SDimitry Andric 
tensorValueToString(const char * Buffer,const TensorSpec & Spec)1077fa27ce4SDimitry Andric std::string tensorValueToString(const char *Buffer, const TensorSpec &Spec) {
1087fa27ce4SDimitry Andric   switch (Spec.type()) {
1097fa27ce4SDimitry Andric #define _IMR_DBG_PRINTER(T, N)                                                 \
1107fa27ce4SDimitry Andric   case TensorType::N: {                                                        \
1117fa27ce4SDimitry Andric     const T *TypedBuff = reinterpret_cast<const T *>(Buffer);                  \
1127fa27ce4SDimitry Andric     auto R = llvm::make_range(TypedBuff, TypedBuff + Spec.getElementCount());  \
1137fa27ce4SDimitry Andric     return llvm::join(                                                         \
1147fa27ce4SDimitry Andric         llvm::map_range(R, [](T V) { return std::to_string(V); }), ",");       \
1157fa27ce4SDimitry Andric   }
1167fa27ce4SDimitry Andric     SUPPORTED_TENSOR_TYPES(_IMR_DBG_PRINTER)
1177fa27ce4SDimitry Andric #undef _IMR_DBG_PRINTER
1187fa27ce4SDimitry Andric   case TensorType::Total:
1197fa27ce4SDimitry Andric   case TensorType::Invalid:
1207fa27ce4SDimitry Andric     llvm_unreachable("invalid tensor type");
1217fa27ce4SDimitry Andric   }
1227fa27ce4SDimitry Andric   // To appease warnings about not all control paths returning a value.
1237fa27ce4SDimitry Andric   return "";
1247fa27ce4SDimitry Andric }
1257fa27ce4SDimitry Andric 
126145449b1SDimitry Andric } // namespace llvm
127