diff --git a/postgis/geometry.py b/postgis/geometry.py index b6a5c87..33494bc 100644 --- a/postgis/geometry.py +++ b/postgis/geometry.py @@ -1,5 +1,8 @@ import warnings +from .ewkb import Reader, Typed, Writer +from .geojson import GeoJSON + try: # Do not make psycopg2 a requirement. from psycopg2.extensions import ISQLQuote @@ -7,8 +10,6 @@ warnings.warn('psycopg2 not installed', ImportWarning) -from .ewkb import Reader, Typed, Writer -from .geojson import GeoJSON class Geometry(object, metaclass=Typed): @@ -56,9 +57,13 @@ def __repr__(self): return '<{} {}>'.format(self.__class__.__name__, self.wkt) def __eq__(self, other): - if isinstance(other, self.__class__): - other = other.coords - return self.coords == other + return hash(self) == hash(other) + + def __hash__(self): + values = self.coords + if self.srid: + values = values + (self.srid,) + return hash(values) @property def name(self): diff --git a/postgis/geometrycollection.py b/postgis/geometrycollection.py index 93f24f2..8c07ce7 100644 --- a/postgis/geometrycollection.py +++ b/postgis/geometrycollection.py @@ -1,18 +1,17 @@ -from .geometry import Geometry from .geojson import GeoJSON +from .geometry import Geometry class GeometryCollection(Geometry): TYPE = 7 - def __init__(self, geoms, srid=None): + def __init__(self, geoms, srid=0): for geom in geoms: if not isinstance(geom, Geometry): raise ValueError('{} is not instance of Geometry'.format(geom)) self.geoms = list(geoms) - if srid: - self.srid = srid + self.srid = srid def __iter__(self): return self.geoms.__iter__() diff --git a/postgis/multi.py b/postgis/multi.py index cca266f..43287b1 100644 --- a/postgis/multi.py +++ b/postgis/multi.py @@ -1,15 +1,15 @@ from .geometry import Geometry from .point import Point + class Multi(Geometry): - __slots__ = ['geoms', 'srid'] + __slots__ = ["geoms", "srid"] SUBCLASS = None - def __init__(self, geoms, srid=None): + def __init__(self, geoms, srid=0): self.geoms = [self.SUBCLASS(g, srid=srid) for g in geoms] - if srid: - self.srid = srid + self.srid = srid def __iter__(self): return iter(self.geoms) @@ -31,8 +31,8 @@ def from_ewkb_body(cls, reader, srid=None): @property def wkt_coords(self): - fmt = '{}' if self.SUBCLASS == Point else '({})' - return ', '.join(fmt.format(g.wkt_coords) for g in self) + fmt = "{}" if self.SUBCLASS == Point else "({})" + return ", ".join(fmt.format(g.wkt_coords) for g in self) def write_ewkb_body(self, writer): writer.write_int(len(self.geoms)) diff --git a/postgis/point.py b/postgis/point.py index 3481b30..9590c8c 100644 --- a/postgis/point.py +++ b/postgis/point.py @@ -6,7 +6,7 @@ class Point(Geometry): __slots__ = ['x', 'y', 'z', 'm', 'srid'] TYPE = 1 - def __init__(self, x, y=None, z=None, m=None, srid=None): + def __init__(self, x, y=None, z=None, m=None, srid=0): if y is None and isinstance(x, (tuple, list)): x, y, *extra = x if extra: @@ -17,8 +17,7 @@ def __init__(self, x, y=None, z=None, m=None, srid=None): self.y = float(y) self.z = float(z) if z is not None else None self.m = float(m) if m is not None else None - if srid is not None: - self.srid = srid + self.srid = srid def __getitem__(self, item): if item in (0, 'x'): diff --git a/tests/test_linestring.py b/tests/test_linestring.py index 24e6192..7b75b21 100644 --- a/tests/test_linestring.py +++ b/tests/test_linestring.py @@ -21,3 +21,13 @@ def test_geom_should_compare_with_coords(): def test_linestring_get_item(): line = LineString(((30, 10), (10, 30), (40, 40))) assert line[0] == (30, 10) + + +def test_linestring_is_hashable(): + l1 = LineString(((1, 2), (3, 4))) + l2 = LineString(((1, 2), (3, 4))) + l3 = LineString(((3, 4), (5, 6))) + assert {l1, l2, l3} == {l1, l3} + l1 = LineString(((1, 2), (3, 4)), srid=4326) + l2 = LineString(((1, 2), (3, 4)), srid=3857) + assert len({l1, l2}) == 2 diff --git a/tests/test_point.py b/tests/test_point.py index 4de9d69..2591a3c 100644 --- a/tests/test_point.py +++ b/tests/test_point.py @@ -98,3 +98,13 @@ def test_0_as_m_is_considered(): assert point.y == 2.0 assert point.z == 3 assert point.m == 0 + + +def test_point_is_hashable(): + p1 = Point(1, 1) + p2 = Point(1, 1) + p3 = Point(2, 2) + assert {p1, p2, p3} == {p1, p3} + p1 = Point(1, 1, srid=4326) + p2 = Point(1, 1, srid=3857) + assert len({p1, p2}) == 2 diff --git a/tests/test_polygon.py b/tests/test_polygon.py index 17854cc..0222c5c 100644 --- a/tests/test_polygon.py +++ b/tests/test_polygon.py @@ -2,18 +2,38 @@ def test_geom_should_compare_with_coords(): - assert (((35, 10), (45, 45), (15, 40), (10, 20), (35, 10)), ((20, 30), (35, 35), (30, 20), (20, 30))) == Polygon((((35, 10), (45, 45), (15, 40), (10, 20), (35, 10)), ((20, 30), (35, 35), (30, 20), (20, 30)))) # noqa + assert ( + ((35, 10), (45, 45), (15, 40), (10, 20), (35, 10)), + ((20, 30), (35, 35), (30, 20), (20, 30)), + ) == Polygon( + ( + ((35, 10), (45, 45), (15, 40), (10, 20), (35, 10)), + ((20, 30), (35, 35), (30, 20), (20, 30)), + ) + ) # noqa def test_polygon_geojson(): poly = Polygon((((1, 2), (3, 4), (5, 6), (1, 2)),)) - assert poly.geojson == {"type": "Polygon", - "coordinates": (((1, 2), (3, 4), (5, 6), (1, 2)),)} + assert poly.geojson == { + "type": "Polygon", + "coordinates": (((1, 2), (3, 4), (5, 6), (1, 2)),), + } def test_polygon_wkt(): poly = Polygon((((1, 2), (3, 4), (5, 6), (1, 2)),)) wkt = poly.wkt - wkt = wkt.replace('.0','') - wkt = wkt.replace(', ',',') - assert wkt == 'POLYGON((1 2,3 4,5 6,1 2))' + wkt = wkt.replace(".0", "") + wkt = wkt.replace(", ", ",") + assert wkt == "POLYGON((1 2,3 4,5 6,1 2))" + + +def test_polygon_is_hashable(): + p1 = Polygon((((1, 2), (3, 4), (5, 6), (1, 2)),)) + p2 = Polygon((((1, 2), (3, 4), (5, 6), (1, 2)),)) + p3 = Polygon((((1, 2), (3, 4), (6, 7), (1, 2)),)) + assert {p1, p2, p3} == {p1, p3} + p1 = Polygon((((1, 2), (3, 4), (5, 6), (1, 2)),), srid=4326) + p2 = Polygon((((1, 2), (3, 4), (5, 6), (1, 2)),), srid=3857) + assert len({p1, p2}) == 2