Use unique_ptr for matrix storage

This commit is contained in:
Danny Robson 2013-08-26 15:10:06 +10:00
parent 03eab1354a
commit be53ad9c75
2 changed files with 31 additions and 36 deletions

View File

@ -29,13 +29,10 @@ using namespace util;
using namespace maths; using namespace maths;
matrix::matrix (size_t _rows, size_t _columns): matrix::matrix (size_t _rows, size_t _columns):
m_rows (_rows), m_rows (_rows),
m_columns (_columns), m_columns (_columns),
m_data (NULL) { m_data (new double[_rows * _columns])
if (m_rows <= 0 || m_columns <= 0) {
throw std::runtime_error ("rows and columns must be positive");
m_data = new double[size ()];
} }
@ -43,17 +40,14 @@ matrix::matrix (size_t _rows,
size_t _columns, size_t _columns,
const std::initializer_list <double> &_data): const std::initializer_list <double> &_data):
m_rows (_rows), m_rows (_rows),
m_columns (_columns), m_columns (_columns)
m_data (NULL)
{ {
if (m_rows <= 0 || m_columns <= 0)
throw std::runtime_error ("rows and columns must be positive");
if (size () != _data.size ()) if (size () != _data.size ())
throw std::runtime_error ("element and initializer size differs"); throw std::runtime_error ("element and initializer size differs");
CHECK_HARD (m_rows * m_columns == _data.size()); CHECK_HARD (m_rows * m_columns == _data.size());
m_data = new double[size ()]; m_data.reset (new double[size ()]);
std::copy (_data.begin (), _data.end (), m_data); std::copy (_data.begin (), _data.end (), m_data.get ());
} }
@ -62,7 +56,7 @@ matrix::matrix (const std::initializer_list <vector> &rhs):
m_columns (rhs.begin()->size ()), m_columns (rhs.begin()->size ()),
m_data (new double[m_rows * m_columns]) m_data (new double[m_rows * m_columns])
{ {
double *row_cursor = m_data; double *row_cursor = m_data.get ();
for (auto i = rhs.begin (); i != rhs.end (); ++i) { for (auto i = rhs.begin (); i != rhs.end (); ++i) {
CHECK (i->size () == m_columns); CHECK (i->size () == m_columns);
@ -74,62 +68,62 @@ matrix::matrix (const std::initializer_list <vector> &rhs):
matrix::matrix (const matrix &rhs): matrix::matrix (const matrix &rhs):
m_rows (rhs.m_rows), m_rows (rhs.m_rows),
m_columns (rhs.m_columns) { m_columns (rhs.m_columns)
m_data = new double [m_rows * m_columns]; {
std::copy (rhs.m_data, rhs.m_data + m_rows * m_columns, m_data); m_data.reset (new double [m_rows * m_columns]);
std::copy (rhs.m_data.get (), rhs.m_data.get () + m_rows * m_columns, m_data.get ());
} }
matrix::matrix (matrix &&rhs): matrix::matrix (matrix &&rhs):
m_rows (rhs.m_rows), m_rows (rhs.m_rows),
m_columns (rhs.m_columns), m_columns (rhs.m_columns),
m_data (rhs.m_data) { m_data (std::move (rhs.m_data))
rhs.m_data = NULL; {
} }
matrix::~matrix() matrix::~matrix()
{ delete [] m_data; } { ; }
void void
matrix::sanity (void) const { matrix::sanity (void) const {
CHECK (m_rows > 0); CHECK (m_rows > 0);
CHECK (m_columns > 0); CHECK (m_columns > 0);
CHECK (m_data != NULL); CHECK (m_data != nullptr);
} }
const double * const double *
matrix::operator [] (unsigned int row) const { matrix::operator [] (unsigned int row) const {
CHECK_HARD (row < m_rows); CHECK_HARD (row < m_rows);
return m_data + row * m_columns; return m_data.get () + row * m_columns;
} }
double * double *
matrix::operator [] (unsigned int row) { matrix::operator [] (unsigned int row) {
CHECK_HARD (row < m_rows); CHECK_HARD (row < m_rows);
return m_data + row * m_columns; return m_data.get () + row * m_columns;
} }
const double * const double *
matrix::data (void) const matrix::data (void) const
{ return m_data; } { return m_data.get (); }
matrix& matrix&
matrix::operator =(const matrix& rhs) { matrix::operator =(const matrix& rhs) {
if (size () != rhs.size ()) { if (size () != rhs.size ()) {
delete [] m_data; m_data.reset (new double [rhs.rows () * rhs.columns ()]);
m_data = new double [m_rows * m_columns];
} }
m_rows = rhs.m_rows; m_rows = rhs.m_rows;
m_columns = rhs.m_columns; m_columns = rhs.m_columns;
std::copy (rhs.m_data, rhs.m_data + m_rows * m_columns, m_data); std::copy (rhs.m_data.get (), rhs.m_data.get () + m_rows * m_columns, m_data.get ());
return *this; return *this;
} }
@ -210,7 +204,7 @@ matrix::operator ==(const matrix& rhs) const {
rhs.columns () != columns ()) rhs.columns () != columns ())
return false; return false;
return std::equal (m_data, m_data + size (), rhs.data ()); return std::equal (m_data.get (), m_data.get () + size (), rhs.data ());
} }
@ -466,7 +460,7 @@ matrix::zeroes (size_t diag)
matrix matrix
matrix::zeroes (size_t rows, size_t columns) { matrix::zeroes (size_t rows, size_t columns) {
matrix m (rows, columns); matrix m (rows, columns);
std::fill (m.m_data, m.m_data + m.size (), 0.0); std::fill (m.m_data.get (), m.m_data.get () + m.size (), 0.0);
return m; return m;
} }

View File

@ -22,18 +22,19 @@
#include "vector.hpp" #include "vector.hpp"
#include <assert.h>
#include <algorithm> #include <algorithm>
#include <stdexcept> #include <assert.h>
#include <initializer_list> #include <initializer_list>
#include <iostream> #include <iostream>
#include <memory>
#include <stdexcept>
namespace maths { namespace maths {
class matrix { class matrix {
protected: protected:
size_t m_rows, size_t m_rows,
m_columns; m_columns;
double *restrict m_data; std::unique_ptr<double[]> m_data;
public: public:
matrix (size_t _rows, size_t _columns); matrix (size_t _rows, size_t _columns);