forked from facebookresearch/pyvrs
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Re-sync with internal repository (facebookresearch#133)
The internal and external repositories are out of sync. This Pull Request attempts to brings them back in sync by patching the GitHub repository. Please carefully review this patch. You must disable ShipIt for your project in order to merge this pull request. DO NOT IMPORT this pull request. Instead, merge it directly on GitHub using the MERGE BUTTON. Re-enable ShipIt after merging.
- Loading branch information
1 parent
bb275d4
commit baaef19
Showing
15 changed files
with
1,786 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
/* | ||
* Copyright (c) Meta Platforms, Inc. and affiliates. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include <pybind11/pybind11.h> | ||
|
||
#include <map> | ||
#include <string> | ||
#include <vector> | ||
|
||
#include <vrs/DataPieces.h> | ||
|
||
namespace pyvrs { | ||
|
||
namespace py = pybind11; | ||
using namespace vrs; | ||
|
||
class DataPieceWrapper { | ||
public: | ||
virtual ~DataPieceWrapper() = default; | ||
}; | ||
|
||
// We should treat DataPieceString separatedly from other DataPiece types. | ||
class DataPieceStringWrapper : public DataPieceWrapper { | ||
public: | ||
void set(std::string& v) { | ||
dpv_->stage(v); | ||
} | ||
void setDataPiece(DataPieceString* dpv) { | ||
dpv_ = dpv; | ||
} | ||
|
||
private: | ||
DataPieceString* dpv_; | ||
}; | ||
|
||
#define DEFINE_DATA_PIECE_VALUE_WRAPPER(TEMPLATE_TYPE) \ | ||
class DataPieceValue##TEMPLATE_TYPE##Wrapper : public DataPieceWrapper { \ | ||
public: \ | ||
void set(TEMPLATE_TYPE& v) { \ | ||
dpv_->set(v); \ | ||
} \ | ||
void setDataPiece(DataPieceValue<TEMPLATE_TYPE>* dpv) { \ | ||
dpv_ = dpv; \ | ||
} \ | ||
\ | ||
private: \ | ||
DataPieceValue<TEMPLATE_TYPE>* dpv_; \ | ||
}; | ||
|
||
#define DEFINE_DATA_PIECE_VECTOR_WRAPPER(TEMPLATE_TYPE) \ | ||
class DataPieceVector##TEMPLATE_TYPE##Wrapper : public DataPieceWrapper { \ | ||
public: \ | ||
void set(std::vector<TEMPLATE_TYPE>& v) { \ | ||
dpv_->stage(v); \ | ||
} \ | ||
void setDataPiece(DataPieceVector<TEMPLATE_TYPE>* dpv) { \ | ||
dpv_ = dpv; \ | ||
} \ | ||
\ | ||
private: \ | ||
DataPieceVector<TEMPLATE_TYPE>* dpv_; \ | ||
}; | ||
|
||
#define DEFINE_DATA_PIECE_ARRAY_WRAPPER(TEMPLATE_TYPE) \ | ||
class DataPieceArray##TEMPLATE_TYPE##Wrapper : public DataPieceWrapper { \ | ||
public: \ | ||
void set(std::vector<TEMPLATE_TYPE>& v) { \ | ||
if (v.size() > dpv_->getArraySize()) { \ | ||
throw py::value_error("Given array does not fit in target field " + dpv_->getLabel()); \ | ||
} \ | ||
dpv_->set(v); \ | ||
} \ | ||
void setDataPiece(DataPieceArray<TEMPLATE_TYPE>* dpv) { \ | ||
dpv_ = dpv; \ | ||
} \ | ||
\ | ||
private: \ | ||
DataPieceArray<TEMPLATE_TYPE>* dpv_; \ | ||
}; | ||
|
||
#define DEFINE_DATA_PIECE_MAP_WRAPPER(TEMPLATE_TYPE) \ | ||
class DataPieceStringMap##TEMPLATE_TYPE##Wrapper : public DataPieceWrapper { \ | ||
public: \ | ||
void set(std::map<std::string, TEMPLATE_TYPE>& v) { \ | ||
dpv_->stage(v); \ | ||
} \ | ||
void setDataPiece(DataPieceStringMap<TEMPLATE_TYPE>* dpv) { \ | ||
dpv_ = dpv; \ | ||
} \ | ||
\ | ||
private: \ | ||
DataPieceStringMap<TEMPLATE_TYPE>* dpv_; \ | ||
}; | ||
|
||
#define DEFINE_ALL_DATA_PIECE_WRAPPER(TEMPLATE_TYPE) \ | ||
DEFINE_DATA_PIECE_VALUE_WRAPPER(TEMPLATE_TYPE) \ | ||
DEFINE_DATA_PIECE_VECTOR_WRAPPER(TEMPLATE_TYPE) \ | ||
DEFINE_DATA_PIECE_ARRAY_WRAPPER(TEMPLATE_TYPE) \ | ||
DEFINE_DATA_PIECE_MAP_WRAPPER(TEMPLATE_TYPE) | ||
|
||
// Define DataPieceWrapper classes | ||
// Define & generate the code for each POD type supported. | ||
#define POD_MACRO DEFINE_ALL_DATA_PIECE_WRAPPER | ||
#include <vrs/helpers/PODMacro.inc> | ||
|
||
DEFINE_DATA_PIECE_MAP_WRAPPER(string) | ||
DEFINE_DATA_PIECE_VECTOR_WRAPPER(string) | ||
|
||
} // namespace pyvrs |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,267 @@ | ||
/* | ||
* Copyright (c) Meta Platforms, Inc. and affiliates. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
// This include *must* be before any STL include! See Python C API doc. | ||
#define PY_SSIZE_T_CLEAN | ||
#include <Python.h> // IWYU pragma: keepo | ||
|
||
// Includes needed for bindings (including marshalling STL containers) | ||
#include <pybind11/attr.h> | ||
#include <pybind11/cast.h> | ||
#include <pybind11/pybind11.h> | ||
#include <pybind11/stl.h> | ||
|
||
#define DEFAULT_LOG_CHANNEL "PyRecordable" | ||
#include <logging/Log.h> | ||
|
||
#include <vrs/RecordFileWriter.h> | ||
#include <vrs/Recordable.h> | ||
|
||
#include "PyRecordable.h" | ||
#include "VRSWriter.h" | ||
|
||
namespace py = pybind11; | ||
|
||
namespace pyvrs { | ||
|
||
bool PyRecordable::addRecordFormat(const PyRecordFormat* recordFormat) { | ||
return Recordable::addRecordFormat( | ||
recordFormat->getRecordType(), | ||
recordFormat->getFormatVersion(), | ||
recordFormat->getRecordFormat(), | ||
recordFormat->getDataLayouts()); | ||
} | ||
|
||
DataSource PyRecordFormat::getDataSource( | ||
const DataSourceChunk& src1, | ||
const DataSourceChunk& src2, | ||
const DataSourceChunk& src3) const { | ||
if (dataLayouts_.size() == 1) { | ||
return DataSource(*dataLayouts_[0], src1, src2, src3); | ||
} | ||
if (dataLayouts_.size() == 2) { | ||
return DataSource(*dataLayouts_[0], *dataLayouts_[1], src1, src2, src3); | ||
} | ||
return DataSource(src1, src2, src3); | ||
} | ||
|
||
const std::vector<const vrs::DataLayout*> PyRecordFormat::getDataLayouts() const { | ||
std::vector<const vrs::DataLayout*> dataLayouts = {}; | ||
for (auto& dataLayout : dataLayouts_) { | ||
dataLayouts.emplace_back(dataLayout.get()); | ||
} | ||
return dataLayouts; | ||
} | ||
|
||
RecordFormat PyRecordFormat::getRecordFormat() const { | ||
RecordFormat format; | ||
// default behavior of DataSource::copyTo is to copy all dataLayouts buffer first, | ||
// then copy other buffers. To make things simple, we will use the same order here. | ||
for (auto& dataLayout : dataLayouts_) { | ||
format = format + dataLayout->getContentBlock(); | ||
} | ||
for (auto& contentBlocks : additionalContentBlocks_) { | ||
for (ContentBlock contentBlock : contentBlocks) { | ||
format = format + contentBlock; | ||
} | ||
} | ||
return format; | ||
} | ||
|
||
#define ADD_DATA_PIECE_WRAPPER_SUPPORT(DATA_PIECE_TYPE, TEMPLATE_TYPE) \ | ||
if (piece->getTypeName() == "DataPiece" #DATA_PIECE_TYPE "<" #TEMPLATE_TYPE ">") { \ | ||
std::unique_ptr<pyvrs::DataPiece##DATA_PIECE_TYPE##TEMPLATE_TYPE##Wrapper> ptr = \ | ||
std::make_unique<pyvrs::DataPiece##DATA_PIECE_TYPE##TEMPLATE_TYPE##Wrapper>(); \ | ||
ptr->setDataPiece(reinterpret_cast<pyvrs::DataPiece##DATA_PIECE_TYPE<TEMPLATE_TYPE>*>(piece)); \ | ||
dataPieceMaps[i][piece->getLabel()] = std::move(ptr); \ | ||
} | ||
|
||
#define ADD_ALL_DATA_PIECE_WRAPPER_SUPPORT(TEMPLATE_TYPE) \ | ||
ADD_DATA_PIECE_WRAPPER_SUPPORT(Value, TEMPLATE_TYPE) \ | ||
ADD_DATA_PIECE_WRAPPER_SUPPORT(Array, TEMPLATE_TYPE) \ | ||
ADD_DATA_PIECE_WRAPPER_SUPPORT(Vector, TEMPLATE_TYPE) \ | ||
ADD_DATA_PIECE_WRAPPER_SUPPORT(StringMap, TEMPLATE_TYPE) | ||
|
||
std::vector<std::map<std::string, std::unique_ptr<pyvrs::DataPieceWrapper>>> | ||
PyRecordFormat::getMembers() { | ||
std::vector<std::map<std::string, std::unique_ptr<pyvrs::DataPieceWrapper>>> dataPieceMaps( | ||
dataLayouts_.size()); | ||
for (size_t i = 0; i < dataLayouts_.size(); i++) { | ||
dataLayouts_[i]->forEachDataPiece([&dataPieceMaps, &i](vrs::DataPiece* piece) { | ||
// Define & generate the code for each POD type supported. | ||
#define POD_MACRO ADD_ALL_DATA_PIECE_WRAPPER_SUPPORT | ||
#include <vrs/helpers/PODMacro.inc> | ||
ADD_DATA_PIECE_WRAPPER_SUPPORT(Vector, string) | ||
ADD_DATA_PIECE_WRAPPER_SUPPORT(StringMap, string) | ||
|
||
if (piece->getTypeName() == "DataPieceString") { | ||
std::unique_ptr<pyvrs::DataPieceStringWrapper> ptr = | ||
std::make_unique<pyvrs::DataPieceStringWrapper>(); | ||
ptr->setDataPiece(reinterpret_cast<vrs::DataPieceString*>(piece)); | ||
dataPieceMaps[i][piece->getLabel()] = std::move(ptr); | ||
} | ||
}); | ||
} | ||
|
||
return dataPieceMaps; | ||
} | ||
|
||
void PyStream::init( | ||
RecordableTypeId typeId, | ||
const string& deviceFlavor, | ||
std::unique_ptr<PyRecordFormat>&& configurationRecordFormat, | ||
std::unique_ptr<PyRecordFormat>&& dataRecordFormat, | ||
std::unique_ptr<PyRecordFormat>&& stateRecordFormat) { | ||
recordable_ = std::make_unique<PyRecordable>(typeId, deviceFlavor); | ||
if (configurationRecordFormat != nullptr) { | ||
recordFormatMap_.insert( | ||
std::make_pair(Record::Type::CONFIGURATION, std::move(configurationRecordFormat))); | ||
} else { | ||
recordFormatMap_.insert(std::make_pair( | ||
Record::Type::CONFIGURATION, | ||
std::make_unique<PyRecordFormat>(Record::Type::CONFIGURATION))); | ||
} | ||
if (dataRecordFormat != nullptr) { | ||
recordFormatMap_.insert(std::make_pair(Record::Type::DATA, std::move(dataRecordFormat))); | ||
} else { | ||
recordFormatMap_.insert( | ||
std::make_pair(Record::Type::DATA, std::make_unique<PyRecordFormat>(Record::Type::DATA))); | ||
} | ||
if (stateRecordFormat != nullptr) { | ||
recordFormatMap_.insert(std::make_pair(Record::Type::STATE, std::move(stateRecordFormat))); | ||
} else { | ||
recordFormatMap_.insert( | ||
std::make_pair(Record::Type::STATE, std::make_unique<PyRecordFormat>(Record::Type::STATE))); | ||
} | ||
} | ||
|
||
PyStream::PyStream(PyStream&& other) { | ||
other.recordable_ = std::move(recordable_); | ||
for (auto& recordFormat : recordFormatMap_) { | ||
other.recordFormatMap_.insert( | ||
std::make_pair(recordFormat.first, std::move(recordFormat.second))); | ||
} | ||
} | ||
|
||
PyRecordFormat* PyStream::createRecordFormat(Record::Type recordType) { | ||
PyRecordFormat* recordFormat = nullptr; | ||
auto iter = recordFormatMap_.find(recordType); | ||
if (iter != recordFormatMap_.end()) { | ||
recordFormat = iter->second.get(); | ||
recordable_->addRecordFormat(recordFormat); | ||
return recordFormat; | ||
} | ||
return recordFormat; | ||
} | ||
|
||
const Record* PyStream::createRecord(double timestamp, const PyRecordFormat* recordFormat) { | ||
return recordable_->createRecord( | ||
timestamp, | ||
recordFormat->getRecordType(), | ||
recordFormat->getFormatVersion(), | ||
recordFormat->getDataSource()); | ||
} | ||
|
||
const Record* | ||
PyStream::createRecord(double timestamp, const PyRecordFormat* recordFormat, py::array buffer) { | ||
py::buffer_info info = buffer.request(); | ||
size_t size = info.itemsize; | ||
for (py::ssize_t i = 0; i < info.ndim; i++) { | ||
size *= info.shape[i]; | ||
} | ||
return recordable_->createRecord( | ||
timestamp, | ||
recordFormat->getRecordType(), | ||
recordFormat->getFormatVersion(), | ||
recordFormat->getDataSource(DataSourceChunk(info.ptr, size))); | ||
} | ||
|
||
const Record* PyStream::createRecord( | ||
double timestamp, | ||
const PyRecordFormat* recordFormat, | ||
py::array buffer, | ||
py::array buffer2) { | ||
py::buffer_info info = buffer.request(); | ||
py::buffer_info info2 = buffer2.request(); | ||
size_t size = info.itemsize; | ||
for (py::ssize_t i = 0; i < info.ndim; i++) { | ||
size *= info.shape[i]; | ||
} | ||
size_t size2 = info2.itemsize; | ||
for (py::ssize_t i = 0; i < info2.ndim; i++) { | ||
size2 *= info2.shape[i]; | ||
} | ||
return recordable_->createRecord( | ||
timestamp, | ||
recordFormat->getRecordType(), | ||
recordFormat->getFormatVersion(), | ||
recordFormat->getDataSource( | ||
DataSourceChunk(info.ptr, size), DataSourceChunk(info2.ptr, size2))); | ||
} | ||
|
||
const Record* PyStream::createRecord( | ||
double timestamp, | ||
const PyRecordFormat* recordFormat, | ||
py::array buffer, | ||
py::array buffer2, | ||
py::array buffer3) { | ||
py::buffer_info info = buffer.request(); | ||
py::buffer_info info2 = buffer2.request(); | ||
py::buffer_info info3 = buffer3.request(); | ||
size_t size = info.itemsize; | ||
for (py::ssize_t i = 0; i < info.ndim; i++) { | ||
size *= info.shape[i]; | ||
} | ||
size_t size2 = info2.itemsize; | ||
for (py::ssize_t i = 0; i < info2.ndim; i++) { | ||
size2 *= info2.shape[i]; | ||
} | ||
size_t size3 = info3.itemsize; | ||
for (py::ssize_t i = 0; i < info3.ndim; i++) { | ||
size3 *= info3.shape[i]; | ||
} | ||
return recordable_->createRecord( | ||
timestamp, | ||
recordFormat->getRecordType(), | ||
recordFormat->getFormatVersion(), | ||
recordFormat->getDataSource( | ||
DataSourceChunk(info.ptr, size), | ||
DataSourceChunk(info2.ptr, size2), | ||
DataSourceChunk(info3.ptr, size3))); | ||
} | ||
|
||
void PyStream::setCompression(CompressionPreset preset) { | ||
recordable_->setCompression(preset); | ||
} | ||
|
||
void PyStream::setTag(const std::string& tagName, const std::string& tagValue) { | ||
recordable_->setTag(tagName, tagValue); | ||
} | ||
|
||
std::string PyStream::getStreamID() { | ||
return recordable_->getStreamId().getNumericName(); | ||
} | ||
|
||
std::vector<std::string> PyRecordFormat::getJsonDataLayouts() const { | ||
std::vector<std::string> v; | ||
for (auto& dataLayout : dataLayouts_) { | ||
v.push_back(dataLayout->asJson()); | ||
} | ||
|
||
return v; | ||
} | ||
|
||
} // namespace pyvrs |
Oops, something went wrong.