Imported Upstream version 60
[debian/goprotobuf.git] / proto / extensions.go
1 // Go support for Protocol Buffers - Google's data interchange format
2 //
3 // Copyright 2010 Google Inc.  All rights reserved.
4 // http://code.google.com/p/goprotobuf/
5 //
6 // Redistribution and use in source and binary forms, with or without
7 // modification, are permitted provided that the following conditions are
8 // met:
9 //
10 //     * Redistributions of source code must retain the above copyright
11 // notice, this list of conditions and the following disclaimer.
12 //     * Redistributions in binary form must reproduce the above
13 // copyright notice, this list of conditions and the following disclaimer
14 // in the documentation and/or other materials provided with the
15 // distribution.
16 //     * Neither the name of Google Inc. nor the names of its
17 // contributors may be used to endorse or promote products derived from
18 // this software without specific prior written permission.
19 //
20 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
23 // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
24 // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
25 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
26 // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
27 // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
28 // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31
32 package proto
33
34
35 /*
36  * Types and routines for supporting protocol buffer extensions.
37  */
38
39 import (
40         "os"
41         "reflect"
42         "strconv"
43         "unsafe"
44 )
45
46 // ExtensionRange represents a range of message extensions for a protocol buffer.
47 // Used in code generated by the protocol compiler.
48 type ExtensionRange struct {
49         Start, End int32 // both inclusive
50 }
51
52 // extendableProto is an interface implemented by any protocol buffer that may be extended.
53 type extendableProto interface {
54         ExtensionRangeArray() []ExtensionRange
55         ExtensionMap() map[int32]Extension
56 }
57
58 // ExtensionDesc represents an extension specification.
59 // Used in generated code from the protocol compiler.
60 type ExtensionDesc struct {
61         ExtendedType  interface{} // nil pointer to the type that is being extended
62         ExtensionType interface{} // nil pointer to the extension type
63         Field         int32       // field number
64         Name          string      // fully-qualified name of extension
65         Tag           string      // protobuf tag style
66 }
67
68 /*
69 Extension represents an extension in a message.
70
71 When an extension is stored in a message using SetExtension
72 only desc and value are set. When the message is marshaled
73 enc will be set to the encoded form of the message.
74
75 When a message is unmarshaled and contains extensions, each
76 extension will have only enc set. When such an extension is
77 accessed using GetExtension (or GetExtensions) desc and value
78 will be set.
79 */
80 type Extension struct {
81         desc  *ExtensionDesc
82         value interface{}
83         enc   []byte
84 }
85
86 // SetRawExtension is for testing only.
87 func SetRawExtension(base extendableProto, id int32, b []byte) {
88         base.ExtensionMap()[id] = Extension{enc: b}
89 }
90
91 // isExtensionField returns true iff the given field number is in an extension range.
92 func isExtensionField(pb extendableProto, field int32) bool {
93         for _, er := range pb.ExtensionRangeArray() {
94                 if er.Start <= field && field <= er.End {
95                         return true
96                 }
97         }
98         return false
99 }
100
101 // checkExtensionTypes checks that the given extension is valid for pb.
102 func checkExtensionTypes(pb extendableProto, extension *ExtensionDesc) os.Error {
103         // Check the extended type.
104         if a, b := reflect.TypeOf(pb), reflect.TypeOf(extension.ExtendedType); a != b {
105                 return os.NewError("bad extended type; " + b.String() + " does not extend " + a.String())
106         }
107         // Check the range.
108         if !isExtensionField(pb, extension.Field) {
109                 return os.NewError("bad extension number; not in declared ranges")
110         }
111         return nil
112 }
113
114 // encodeExtensionMap encodes any unmarshaled (unencoded) extensions in m.
115 func encodeExtensionMap(m map[int32]Extension) os.Error {
116         for k, e := range m {
117                 if e.value == nil || e.desc == nil {
118                         // Extension is only in its encoded form.
119                         continue
120                 }
121
122                 // We don't skip extensions that have an encoded form set,
123                 // because the extension value may have been mutated after
124                 // the last time this function was called.
125
126                 et := reflect.TypeOf(e.desc.ExtensionType)
127                 props := new(Properties)
128                 props.Init(et, "unknown_name", e.desc.Tag, 0)
129
130                 p := NewBuffer(nil)
131                 // The encoder must be passed a pointer to e.value.
132                 // Allocate a copy of value so that we can use its address.
133                 x := reflect.New(et)
134                 x.Elem().Set(reflect.ValueOf(e.value))
135                 if err := props.enc(p, props, x.Pointer()); err != nil {
136                         return err
137                 }
138                 e.enc = p.buf
139                 m[k] = e
140         }
141         return nil
142 }
143
144 // HasExtension returns whether the given extension is present in pb.
145 func HasExtension(pb extendableProto, extension *ExtensionDesc) bool {
146         // TODO: Check types, field numbers, etc.?
147         _, ok := pb.ExtensionMap()[extension.Field]
148         return ok
149 }
150
151 // ClearExtension removes the given extension from pb.
152 func ClearExtension(pb extendableProto, extension *ExtensionDesc) {
153         // TODO: Check types, field numbers, etc.?
154         pb.ExtensionMap()[extension.Field] = Extension{}, false
155 }
156
157 // GetExtension parses and returns the given extension of pb.
158 // If the extension is not present it returns (nil, nil).
159 func GetExtension(pb extendableProto, extension *ExtensionDesc) (interface{}, os.Error) {
160         if err := checkExtensionTypes(pb, extension); err != nil {
161                 return nil, err
162         }
163
164         e, ok := pb.ExtensionMap()[extension.Field]
165         if !ok {
166                 return nil, nil // not an error
167         }
168         if e.value != nil {
169                 // Already decoded. Check the descriptor, though.
170                 if e.desc != extension {
171                         // This shouldn't happen. If it does, it means that
172                         // GetExtension was called twice with two different
173                         // descriptors with the same field number.
174                         return nil, os.NewError("proto: descriptor conflict")
175                 }
176                 return e.value, nil
177         }
178
179         v, err := decodeExtension(e.enc, extension)
180         if err != nil {
181                 return nil, err
182         }
183
184         // Remember the decoded version and drop the encoded version.
185         // That way it is safe to mutate what we return.
186         e.value = v
187         e.desc = extension
188         e.enc = nil
189         return e.value, nil
190 }
191
192 // decodeExtension decodes an extension encoded in b.
193 func decodeExtension(b []byte, extension *ExtensionDesc) (interface{}, os.Error) {
194         // Discard wire type and field number varint. It isn't needed.
195         _, n := DecodeVarint(b)
196         o := NewBuffer(b[n:])
197
198         t := reflect.TypeOf(extension.ExtensionType)
199         props := &Properties{}
200         props.Init(t, "irrelevant_name", extension.Tag, 0)
201
202         base := unsafe.New(t)
203         var sbase uintptr
204         if t.Elem().Kind() == reflect.Struct {
205                 // props.dec will be dec_struct_message, which does not refer to sbase.
206                 *(*unsafe.Pointer)(base) = unsafe.New(t.Elem())
207         } else {
208                 sbase = uintptr(unsafe.New(t.Elem()))
209         }
210         if err := props.dec(o, props, uintptr(base), sbase); err != nil {
211                 return nil, err
212         }
213         return unsafe.Unreflect(t, base), nil
214 }
215
216 // GetExtensions returns a slice of the extensions present in pb that are also listed in es.
217 // The returned slice has the same length as es; missing extensions will appear as nil elements.
218 func GetExtensions(pb interface{}, es []*ExtensionDesc) (extensions []interface{}, err os.Error) {
219         epb, ok := pb.(extendableProto)
220         if !ok {
221                 err = os.NewError("not an extendable proto")
222                 return
223         }
224         extensions = make([]interface{}, len(es))
225         for i, e := range es {
226                 extensions[i], err = GetExtension(epb, e)
227                 if err != nil {
228                         return
229                 }
230         }
231         return
232 }
233
234 // TODO: (needed for repeated extensions)
235 //   - ExtensionSize
236 //   - AddExtension
237
238 // SetExtension sets the specified extension of pb to the specified value.
239 func SetExtension(pb extendableProto, extension *ExtensionDesc, value interface{}) os.Error {
240         if err := checkExtensionTypes(pb, extension); err != nil {
241                 return err
242         }
243         typ := reflect.TypeOf(extension.ExtensionType)
244         if typ != reflect.TypeOf(value) {
245                 return os.NewError("bad extension value type")
246         }
247
248         pb.ExtensionMap()[extension.Field] = Extension{desc: extension, value: value}
249         return nil
250 }
251
252 // A global registry of extensions.
253 // The generated code will register the generated descriptors by calling RegisterExtension.
254
255 var extensionMaps = make(map[reflect.Type]map[int32]*ExtensionDesc)
256
257 // RegisterExtension is called from the generated code.
258 func RegisterExtension(desc *ExtensionDesc) {
259         st := reflect.TypeOf(desc.ExtendedType).Elem()
260         m := extensionMaps[st]
261         if m == nil {
262                 m = make(map[int32]*ExtensionDesc)
263                 extensionMaps[st] = m
264         }
265         if _, ok := m[desc.Field]; ok {
266                 panic("proto: duplicate extension registered: " + st.String() + " " + strconv.Itoa(int(desc.Field)))
267         }
268         m[desc.Field] = desc
269 }
270
271 // RegisteredExtensions returns a map of the registered extensions of a
272 // protocol buffer struct, indexed by the extension number.
273 // The argument pb should be a nil pointer to the struct type.
274 func RegisteredExtensions(pb interface{}) map[int32]*ExtensionDesc {
275         return extensionMaps[reflect.TypeOf(pb).Elem()]
276 }