From 87948b68fc742d5fbea3856c1f75ae2e99dd13aa Mon Sep 17 00:00:00 2001 From: Mitchell Hashimoto Date: Fri, 16 Jan 2015 09:32:15 -0800 Subject: [PATCH] helper/schema: use interface for equality check /cc @svanharmelen --- helper/schema/equal.go | 6 ++++++ helper/schema/resource_data.go | 10 +++++----- helper/schema/set.go | 10 ++++++++++ 3 files changed, 21 insertions(+), 5 deletions(-) create mode 100644 helper/schema/equal.go diff --git a/helper/schema/equal.go b/helper/schema/equal.go new file mode 100644 index 000000000..d5e20e038 --- /dev/null +++ b/helper/schema/equal.go @@ -0,0 +1,6 @@ +package schema + +// Equal is an interface that checks for deep equality between two objects. +type Equal interface { + Equal(interface{}) bool +} diff --git a/helper/schema/resource_data.go b/helper/schema/resource_data.go index 35107f3a4..924b9a877 100644 --- a/helper/schema/resource_data.go +++ b/helper/schema/resource_data.go @@ -106,11 +106,11 @@ func (d *ResourceData) getRaw(key string, level getSource) getResult { func (d *ResourceData) HasChange(key string) bool { o, n := d.GetChange(key) - // There is a special case needed for *schema.Set's as they contain - // a function and reflect.DeepEqual will only say two functions are - // equal when they are both nil (which in this case they are not). - if reflect.TypeOf(o).String() == "*schema.Set" { - return !reflect.DeepEqual(o.(*Set).m, n.(*Set).m) + // If the type implements the Equal interface, then call that + // instead of just doing a reflect.DeepEqual. An example where this is + // needed is *Set + if eq, ok := o.(Equal); ok { + return !eq.Equal(n) } return !reflect.DeepEqual(o, n) diff --git a/helper/schema/set.go b/helper/schema/set.go index 78c68bdb7..965ad6f8b 100644 --- a/helper/schema/set.go +++ b/helper/schema/set.go @@ -2,6 +2,7 @@ package schema import ( "fmt" + "reflect" "sort" "sync" ) @@ -101,6 +102,15 @@ func (s *Set) Union(other *Set) *Set { return result } +func (s *Set) Equal(raw interface{}) bool { + other, ok := raw.(*Set) + if !ok { + return false + } + + return reflect.DeepEqual(s.m, other.m) +} + func (s *Set) GoString() string { return fmt.Sprintf("*Set(%#v)", s.m) }