From 71bb209a43deb8942c0aca0c5272d02fbf573e67 Mon Sep 17 00:00:00 2001
From: thopri <thopri@noc.ac.uk>
Date: Wed, 18 Mar 2020 14:38:08 +0000
Subject: [PATCH] added salinity check

---
 pynemo/profile.py       |  3 ++-
 unit_tests/gen_tools.py | 12 ++++++------
 unit_tests/test_gen.py  |  4 +++-
 unit_tests/unit_test.py | 20 ++++++++++----------
 4 files changed, 21 insertions(+), 18 deletions(-)

diff --git a/pynemo/profile.py b/pynemo/profile.py
index 7931c68..9110828 100644
--- a/pynemo/profile.py
+++ b/pynemo/profile.py
@@ -505,7 +505,8 @@ def process_bdy(setup_filepath=0, mask_gui=False):
         if settings['use_cmems'] == True:
             logger.info('using CMEMS variable names......')
             if ln_tra:
-                var_in['t'].extend(['thetao'])  # ,'so'])
+                var_in['t'].extend(['thetao'])
+                var_in['t'].extend(['so'])
 
             if ln_dyn2d or ln_dyn3d:
                 var_in['u'].extend(['uo'])
diff --git a/unit_tests/gen_tools.py b/unit_tests/gen_tools.py
index b7ed9b9..6dddbf9 100644
--- a/unit_tests/gen_tools.py
+++ b/unit_tests/gen_tools.py
@@ -557,7 +557,6 @@ def write_parameter(fileout, grid_h,grid_z,params):
     dataset.createDimension('z', nz)
     dataset.createDimension('time', nt)
 
-
     # Create Variables
     longitude = dataset.createVariable('longitude', np.float32, ('y', 'x'))
     latitude = dataset.createVariable('latitude', np.float32, ('y', 'x'))
@@ -577,11 +576,12 @@ def write_parameter(fileout, grid_h,grid_z,params):
     latitude[:, :] = grid_h['latt'].T
     depth[:] = grid_z['dept_1d']
     time_counter[:] = np.linspace(587340.00,588060.00,31)
-    parameter = dataset.createVariable(str(params['name']), np.float64, ('time','z', 'y', 'x'))
-    parameter.units, parameter.long_name = str(params['units']), str(params['longname'])
-    value_fill = np.ones(np.shape(grid_z['e3t']))
-    value_fill = value_fill*params['const_value']
-    parameter[:, :, :] = value_fill.T
+    for key in params:
+        parameter = dataset.createVariable(str(params[key]['name']), np.float64, ('time','z', 'y', 'x'))
+        parameter.units, parameter.long_name = str(params[key]['units']), str(params[key]['longname'])
+        value_fill = np.ones(np.shape(grid_z['e3t']))
+        value_fill = value_fill*params[key]['const_value']
+        parameter[:, :, :] = value_fill.T
 
     # Close off pointer
     dataset.close()
diff --git a/unit_tests/test_gen.py b/unit_tests/test_gen.py
index 8e8cd19..f7825cf 100644
--- a/unit_tests/test_gen.py
+++ b/unit_tests/test_gen.py
@@ -105,7 +105,9 @@ def _main():
 
     # write boundary files (constant parameters)
     out_fname = 'unit_tests/test_data/output_boundary_T.nc'
-    params = {'name':'thetao','const_value':15.0,'longname':'temperature','units':'degreesC'}
+    params = {'param1': {'name':'thetao','const_value':15.0,'longname':'temperature','units':'degreesC'},
+              'param2': {'name':'so','const_value':35.0,'longname':'salinity','units':'PSU'}
+              }
     boundary = gt.write_parameter(out_fname,grid_h1,grid_z1,params)
     if boundary == 0:
         print('Success!')
diff --git a/unit_tests/unit_test.py b/unit_tests/unit_test.py
index c5d6784..7757c48 100644
--- a/unit_tests/unit_test.py
+++ b/unit_tests/unit_test.py
@@ -33,16 +33,16 @@ def test_temp():
         assert abs(temp_[temp_ != 0.0].max() - 15) <= 0.001
         assert abs(temp_[temp_ != 0.0].min() - 15) <= 0.001
 
-#def test_salinty():
-#    test_files = glob.glob('unit_tests/test_outputs/unit_test*')
-#    for t in test_files:
-#        results = Dataset(t)  # open results
-#        sal = results['so'][:]
-#        results.close()
-#        sal_ = np.ma.masked_array(sal,sal == -32767.0)
-#        assert abs(sal_[sal_!=0.0].mean() - 35) <= 0.001
-#        assert abs(sal_[sal_ != 0.0].max() - 35) <= 0.001
-#        assert abs(sal_[sal_ != 0.0].min() - 35) <= 0.001
+def test_salinty():
+    test_files = glob.glob('unit_tests/test_outputs/unit_test*')
+    for t in test_files:
+        results = Dataset(t)  # open results
+        sal = results['so'][:]
+        results.close()
+        sal_ = np.ma.masked_array(sal,sal == -32767.0)
+        assert abs(sal_[sal_!=0.0].mean() - 35) <= 0.001
+        assert abs(sal_[sal_ != 0.0].max() - 35) <= 0.001
+        assert abs(sal_[sal_ != 0.0].min() - 35) <= 0.001
 
 # clean up test I/O
 files = glob.glob('unit_tests/test_outputs/*')
-- 
GitLab